You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
synctv/server/handlers/proxy/slice.go

365 lines
9.8 KiB
Go

package proxy
import (
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"github.com/zijiren233/ksync"
"github.com/zijiren233/stream"
)
// mu is the package-wide keyed mutex obtained from ksync.DefaultKmutex.
// getCacheItem locks it with a slice's cache key and releases it on return,
// so the lock is held across that slice's cache lookup, source fetch and
// cache store.
var mu = ksync.DefaultKmutex()
// Proxy defines the interface for proxy implementations.
// It is the content source read by SliceCacheProxy: a seekable reader that
// also reports the total size and the media type of its content.
type Proxy interface {
// Seek and Read are used by fetchFromSource to load one slice at a time.
io.ReadSeeker
// ContentTotalLength returns the total size of the content in bytes.
// Within this file a result of -1 is treated as "source does not support
// range requests" and is turned into an error.
ContentTotalLength() (int64, error)
// ContentType returns the value that is later served as the Content-Type
// response header.
ContentType() (string, error)
}
// Headers defines the interface for accessing response headers.
// Implementing it is optional: fetchFromSource type-asserts its source to
// Headers and, when the assertion succeeds, stores a clone of the returned
// headers in the cache item's metadata; otherwise an empty header map is used.
type Headers interface {
// Headers returns the header set to be cached alongside the slice data.
Headers() http.Header
}
// SliceCacheProxy implements caching of content slices.
// It reads a source in sliceSize-byte slices, stores every slice in a Cache
// and answers HTTP requests (plain or single-range) from those slices.
type SliceCacheProxy struct {
// r is the source that a slice is read from on a cache miss.
r Proxy
// cache stores and returns slices by cache key.
cache Cache
// key identifies the content; its SHA-256 hash is part of every cache key.
key string
// sliceSize is the slice length in bytes; slice offsets are multiples of it.
sliceSize int64
}
// NewSliceCacheProxy creates a new SliceCacheProxy instance.
//
// key identifies the content and is hashed into every cache key, sliceSize is
// the slice length in bytes, r is the source read on a cache miss and cache is
// where fetched slices are stored. The arguments are stored as given; nothing
// is validated here.
func NewSliceCacheProxy(key string, sliceSize int64, r Proxy, cache Cache) *SliceCacheProxy {
	p := new(SliceCacheProxy)
	p.key = key
	p.sliceSize = sliceSize
	p.r = r
	p.cache = cache
	return p
}
// cacheKey builds the cache key of a single slice in the form
// "<hex SHA-256 of key>-<sliceSize>-<offset>".
func cacheKey(key string, offset int64, sliceSize int64) string {
	digest := sha256.Sum256(stream.StringToBytes(key))
	hexDigest := hex.EncodeToString(digest[:])

	return fmt.Sprintf("%s-%d-%d", hexDigest, sliceSize, offset)
}
// cachePrefix builds the part shared by all slice keys of one content item
// and slice size, in the form "<hex SHA-256 of key>-<sliceSize>".
func cachePrefix(key string, sliceSize int64) string {
	digest := sha256.Sum256(stream.StringToBytes(key))
	hexDigest := hex.EncodeToString(digest[:])

	return fmt.Sprintf("%s-%d", hexDigest, sliceSize)
}
// alignedOffset rounds offset toward zero to a multiple of sliceSize, i.e. it
// returns the start offset of the slice that contains offset.
// A zero sliceSize causes an integer divide-by-zero panic.
func alignedOffset(offset, sliceSize int64) int64 {
	intoSlice := offset % sliceSize
	return offset - intoSlice
}
// fmtContentRange formats the value of a Content-Range response header.
//
// start and end are the requested first and last byte positions (end == -1
// means "until the end"); total is the complete content length, negative when
// unknown.
//
//   - Unknown total and open end: "bytes */*".
//   - Unknown total and explicit end: "bytes <start>-<end>/*".
//   - Known total: the last position is clamped to total-1, so the header always
//     agrees with the length computed by contentLength for the same arguments.
//   - Known total and start beyond the last byte: "bytes */<total>", the form
//     used for a range that cannot be satisfied.
func fmtContentRange(start, end, total int64) string {
	if total < 0 {
		if end == -1 {
			return "bytes */*"
		}
		return fmt.Sprintf("bytes %d-%d/*", start, end)
	}

	// A client may ask for more than exists (e.g. bytes=0-999 on 100 bytes);
	// only positions up to total-1 are actually served.
	last := total - 1
	if end != -1 && end < last {
		last = end
	}

	if start > last {
		return "bytes */" + strconv.FormatInt(total, 10)
	}
	return fmt.Sprintf("bytes %d-%d/%d", start, last, total)
}
// contentLength returns how many bytes the range start..end covers.
//
// start and end are the first and last byte positions (end == -1 means "until
// the end"); total is the complete content length, or -1 when unknown.
//
// The result is -1 only when both total and end are unknown. With a known
// total the range is clamped to the last existing byte, and a range that
// starts beyond it yields 0 rather than a negative number, so a negative
// difference can never be mistaken for the -1 "unknown" sentinel.
func contentLength(start, end, total int64) int64 {
	if total == -1 {
		if end == -1 {
			return -1
		}
		return end - start + 1
	}

	// Last byte position that can actually be served.
	last := total - 1
	if end != -1 && end < last {
		last = end
	}

	if start > last {
		return 0
	}
	return last - start + 1
}
// fmtContentLength renders the result of contentLength as a decimal string
// for the Content-Length header. It returns the empty string when
// contentLength reports -1.
func fmtContentLength(start, end, total int64) string {
	if size := contentLength(start, end, total); size != -1 {
		return strconv.FormatInt(size, 10)
	}
	return ""
}
// ServeHTTP implements http.Handler interface.
// It delegates to Proxy and deliberately discards the returned error: for
// failures that happen before the response header is written, Proxy has
// already sent an error response with http.Error, and http.Handler offers no
// way to return an error to the caller.
func (c *SliceCacheProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
_ = c.Proxy(w, r)
}
// Proxy serves the request from cached slices and reports what went wrong.
//
// It parses the Range header, loads the slice that contains the first
// requested byte, writes the response headers and then streams the body.
// A Range header that cannot be parsed is answered with 400 and a failure to
// load the first slice with 500. Errors that occur while the body is being
// written are only returned, because the status line has already been sent.
func (c *SliceCacheProxy) Proxy(w http.ResponseWriter, r *http.Request) error {
	rangeHeader := r.Header.Get("Range")

	requested, err := ParseByteRange(rangeHeader)
	if err != nil {
		http.Error(w, fmt.Sprintf("Failed to parse Range header: %v", err), http.StatusBadRequest)
		return fmt.Errorf("failed to parse Range header: %w", err)
	}

	// Start offset of the slice holding the first requested byte.
	sliceStart := alignedOffset(requested.Start, c.sliceSize)

	firstItem, fromCache, err := c.getCacheItem(sliceStart)
	if err != nil {
		http.Error(w, fmt.Sprintf("Failed to get cache item: %v", err), http.StatusInternalServerError)
		return fmt.Errorf("failed to get cache item: %w", err)
	}

	c.setResponseHeaders(w, requested, firstItem, fromCache, rangeHeader != "")

	if err = c.writeResponse(w, requested, sliceStart, firstItem); err != nil {
		return fmt.Errorf("failed to write response: %w", err)
	}
	return nil
}
// cacheStatusHeader is the response header that setResponseHeaders sets to
// "HIT" or "MISS", reporting whether the first slice of the response was
// found in the cache.
const cacheStatusHeader = "X-Cache-Status"
// setResponseHeaders fills in the response headers and writes the status line.
//
// Headers stored with the cache item are copied first, except Content-Type,
// Content-Length, Content-Range and Accept-Ranges, which are always computed
// here from the requested range and the item's metadata. The cache status
// header is set to "HIT" or "MISS" according to cached. Range requests get a
// Content-Range header and status 206; all other requests get status 200.
func (c *SliceCacheProxy) setResponseHeaders(w http.ResponseWriter, byteRange *ByteRange, cacheItem *CacheItem, cached bool, isRangeRequest bool) {
	out := w.Header()
	meta := cacheItem.Metadata

	for name, values := range meta.Headers {
		// These four are derived below and must not be taken from the cache.
		if name == "Content-Type" || name == "Content-Length" || name == "Content-Range" || name == "Accept-Ranges" {
			continue
		}
		out[name] = values
	}

	status := "MISS"
	if cached {
		status = "HIT"
	}
	out.Set(cacheStatusHeader, status)

	out.Set("Accept-Ranges", "bytes")
	out.Set("Content-Length", fmtContentLength(byteRange.Start, byteRange.End, meta.ContentTotalLength))
	out.Set("Content-Type", meta.ContentType)

	if !isRangeRequest {
		w.WriteHeader(http.StatusOK)
		return
	}

	out.Set("Content-Range", fmtContentRange(byteRange.Start, byteRange.End, meta.ContentTotalLength))
	w.WriteHeader(http.StatusPartialContent)
}
// writeResponse streams the requested byte range to w.
//
// cacheItem is the already loaded slice that starts at alignedOffset and
// contains byteRange.Start. Its data is written from the position of the
// first requested byte; further slices are then loaded one after another with
// getCacheItem until the number of bytes computed by contentLength has been
// written. Nothing is written when that number is zero or negative.
func (c *SliceCacheProxy) writeResponse(w http.ResponseWriter, byteRange *ByteRange, alignedOffset int64, cacheItem *CacheItem) error {
	sliceOffset := byteRange.Start - alignedOffset
	if sliceOffset < 0 {
		return fmt.Errorf("slice offset cannot be negative, got: %d", sliceOffset)
	}

	remaining := contentLength(byteRange.Start, byteRange.End, cacheItem.Metadata.ContentTotalLength)
	if remaining <= 0 {
		return nil
	}

	// First slice: skip the bytes that precede the requested start.
	if avail := int64(len(cacheItem.Data)) - sliceOffset; avail > 0 {
		if avail > remaining {
			avail = remaining
		}
		if _, err := w.Write(cacheItem.Data[sliceOffset : sliceOffset+avail]); err != nil {
			return fmt.Errorf("failed to write initial data slice: %w", err)
		}
		remaining -= avail
	}

	// Following slices are written from their beginning.
	for next := alignedOffset + c.sliceSize; remaining > 0; next += c.sliceSize {
		item, _, err := c.getCacheItem(next)
		if err != nil {
			return fmt.Errorf("failed to get cache item at offset %d: %w", next, err)
		}

		chunk := item.Data
		if int64(len(chunk)) > remaining {
			chunk = chunk[:remaining]
		}
		if len(chunk) == 0 {
			continue
		}

		if _, err := w.Write(chunk); err != nil {
			return fmt.Errorf("failed to write data slice at offset %d: %w", next, err)
		}
		remaining -= int64(len(chunk))
	}

	return nil
}
// getCacheItem returns the slice that starts at alignedOffset.
//
// The slice is looked up in the cache first; on a miss it is read from the
// source and stored in the cache before being returned. The boolean result is
// true only when the slice came from the cache. The keyed mutex mu is held
// for this slice's cache key during the whole call.
func (c *SliceCacheProxy) getCacheItem(alignedOffset int64) (*CacheItem, bool, error) {
	if alignedOffset < 0 {
		return nil, false, fmt.Errorf("cache item offset cannot be negative, got: %d", alignedOffset)
	}

	sliceKey := cacheKey(c.key, alignedOffset, c.sliceSize)
	mu.Lock(sliceKey)
	defer mu.Unlock(sliceKey)

	cachedItem, found, err := c.cache.Get(sliceKey)
	switch {
	case err != nil:
		return nil, false, fmt.Errorf("failed to get item from cache: %w", err)
	case found:
		return cachedItem, true, nil
	}

	// Cache miss: read the slice from the source and remember it.
	fetched, err := c.fetchFromSource(alignedOffset)
	if err != nil {
		return nil, false, fmt.Errorf("failed to fetch item from source: %w", err)
	}
	if err := c.cache.Set(sliceKey, fetched); err != nil {
		return nil, false, fmt.Errorf("failed to store item in cache: %w", err)
	}

	return fetched, false, nil
}
// contentTotalLength asks the source for the total content length.
// It returns -1 and an error when the source fails or when it reports -1,
// which is treated as "source does not support range requests".
func (c *SliceCacheProxy) contentTotalLength() (int64, error) {
	size, err := c.r.ContentTotalLength()
	switch {
	case err != nil:
		return -1, fmt.Errorf("failed to get content total length from source: %w", err)
	case size == -1:
		return -1, errors.New("source does not support range requests")
	default:
		return size, nil
	}
}
// fetchFromSource reads the slice that starts at offset from the source and
// returns it as a new CacheItem together with the source's metadata.
//
// It seeks to offset and tries to read sliceSize bytes. A short read is
// accepted only for the last slice, i.e. when offset plus the bytes read
// equals the total length reported by the source; any other read error,
// including reading nothing at all, fails the call. The metadata holds the
// total length, the content type and, if the source implements Headers, a
// clone of its headers (otherwise an empty header map).
func (c *SliceCacheProxy) fetchFromSource(offset int64) (*CacheItem, error) {
	if offset < 0 {
		return nil, fmt.Errorf("source offset cannot be negative, got: %d", offset)
	}

	if _, err := c.r.Seek(offset, io.SeekStart); err != nil {
		return nil, fmt.Errorf("failed to seek to offset %d in source: %w", offset, err)
	}

	data := make([]byte, c.sliceSize)
	read, readErr := io.ReadFull(c.r, data)
	shortRead := errors.Is(readErr, io.ErrUnexpectedEOF)
	if readErr != nil && !shortRead {
		return nil, fmt.Errorf("failed to read %d bytes from source at offset %d: %w", c.sliceSize, offset, readErr)
	}

	size, err := c.contentTotalLength()
	if err != nil {
		return nil, fmt.Errorf("failed to get content total length from source: %w", err)
	}
	// A short read is only legitimate when it ends exactly at the content end.
	if reached := offset + int64(read); shortRead && size != reached {
		return nil, fmt.Errorf("source content total length mismatch, got: %d, expected: %d, %w", size, reached, io.ErrUnexpectedEOF)
	}

	mediaType, err := c.r.ContentType()
	if err != nil {
		return nil, fmt.Errorf("failed to get content type from source: %w", err)
	}

	meta := &CacheMetadata{
		ContentTotalLength: size,
		ContentType:        mediaType,
	}
	if src, ok := c.r.(Headers); ok {
		meta.Headers = src.Headers().Clone()
	} else {
		meta.Headers = make(http.Header)
	}

	return &CacheItem{Metadata: meta, Data: data[:read]}, nil
}
// ByteRange represents a single byte range taken from an HTTP Range header.
type ByteRange struct {
	// Start is the zero-based position of the first requested byte.
	Start int64
	// End is the zero-based position of the last requested byte (inclusive),
	// or -1 when the range is open-ended.
	End int64
}

// ParseByteRange parses a Range header value in the format:
//
//	bytes=<start>-<end>
//
// where end is optional ("bytes=100-" means from byte 100 to the end).
// An empty header yields the whole content: Start 0 and End -1.
//
// An error is returned when the unit is not "bytes", when several ranges are
// given, when a value is not a valid integer, when start is greater than end,
// and for suffix ranges ("bytes=-500", the last 500 bytes). ByteRange cannot
// express a suffix range, so it is rejected instead of being misread as the
// first bytes of the content.
func ParseByteRange(r string) (*ByteRange, error) {
	if r == "" {
		return &ByteRange{Start: 0, End: -1}, nil
	}

	if !strings.HasPrefix(r, "bytes=") {
		return nil, fmt.Errorf("range header must start with 'bytes=', got: %s", r)
	}
	if strings.Contains(r, ",") {
		return nil, fmt.Errorf("not support multi-range, got: %s", r)
	}

	r = strings.TrimPrefix(r, "bytes=")
	parts := strings.Split(r, "-")
	if len(parts) != 2 {
		return nil, fmt.Errorf("range header must contain exactly one hyphen (-) separator, got: %s", r)
	}

	first := strings.TrimSpace(parts[0])
	last := strings.TrimSpace(parts[1])

	switch {
	case first == "" && last == "":
		return nil, fmt.Errorf("range header cannot have empty start and end values: %s", r)
	case first == "":
		return nil, fmt.Errorf("not support suffix range, got: %s", r)
	}

	// Neither part can contain '-' after the split above, so successfully
	// parsed values are never negative.
	start, err := strconv.ParseInt(first, 10, 64)
	if err != nil {
		return nil, fmt.Errorf("failed to parse range start value '%s': %w", first, err)
	}

	end := int64(-1)
	if last != "" {
		end, err = strconv.ParseInt(last, 10, 64)
		if err != nil {
			return nil, fmt.Errorf("failed to parse range end value '%s': %w", last, err)
		}
		if start > end {
			return nil, fmt.Errorf("range start value (%d) cannot be greater than end value (%d)", start, end)
		}
	}

	return &ByteRange{Start: start, End: end}, nil
}