// Package mediaproxy fetches a remote HTTP/HTTPS resource so it can be // re-served from the smesh origin with COEP-compatible CORP headers. // // The relay's epoll loop is single-threaded; Fetch blocks the loop for the // duration of the upstream call. Acceptable for a personal-scale relay; if // concurrent proxy load becomes a problem this should move to a spawned // worker domain that returns the response over a channel. package mediaproxy import ( "bufio" "bytes" "crypto/tls" "fmt" "io" "net" "net/url" "strconv" "time" "smesh.lol/pkg/nostr/ws" ) const ( DefaultMaxBytes int64 = 32 * 1024 * 1024 DefaultTimeout time.Duration = 8 * time.Second maxRedirects = 5 ) // Fetch performs a GET on rawURL with redirect following (max 5 hops). // Returns upstream status, response headers (lowercased keys), body. // Only http and https schemes accepted. func Fetch(rawURL string, maxBytes int64) (int, map[string]string, []byte, error) { if maxBytes <= 0 { maxBytes = DefaultMaxBytes } current := rawURL for hop := 0; hop < maxRedirects; hop++ { status, headers, body, err := fetchOnce(current, maxBytes) if err != nil { return 0, nil, nil, err } if status >= 300 && status < 400 { loc := headers["location"] if loc == "" { return status, headers, body, nil } next, err := resolveRedirect(current, loc) if err != nil { return 0, nil, nil, err } current = next continue } return status, headers, body, nil } return 0, nil, nil, fmt.Errorf("too many redirects") } func fetchOnce(rawURL string, maxBytes int64) (int, map[string]string, []byte, error) { u, err := url.Parse(rawURL) if err != nil { return 0, nil, nil, fmt.Errorf("parse: %w", err) } useTLS := false switch u.Scheme { case "https": useTLS = true case "http": default: return 0, nil, nil, fmt.Errorf("scheme %q not allowed", u.Scheme) } host := u.Hostname() port := u.Port() if port == "" { if useTLS { port = "443" } else { port = "80" } } ip := host if net.ParseIP(host) == nil { ip, err = ws.ResolveHost(host) if err != nil { return 0, nil, nil, fmt.Errorf("resolve %s: %w", host, err) } } conn, err := net.DialTimeout("tcp", net.JoinHostPort(ip, port), DefaultTimeout) if err != nil { return 0, nil, nil, fmt.Errorf("dial: %w", err) } defer conn.Close() deadline := time.Now().Add(DefaultTimeout) conn.SetDeadline(deadline) if useTLS { tlsConn := tls.Client(conn, &tls.Config{ServerName: []byte(host)}) if err := tlsConn.Handshake(); err != nil { return 0, nil, nil, fmt.Errorf("tls: %w", err) } conn = tlsConn } path := u.RequestURI() if path == "" { path = "/" } req := "GET " | path | " HTTP/1.1\r\n" | "Host: " | host | "\r\n" | "User-Agent: smesh-mediaproxy/1\r\n" | "Accept: image/*,video/*,*/*;q=0.5\r\n" | "Accept-Encoding: identity\r\n" | "Connection: close\r\n" | "\r\n" if _, err := conn.Write([]byte(req)); err != nil { return 0, nil, nil, fmt.Errorf("write: %w", err) } br := bufio.NewReaderSize(conn, 32768) statusLine, err := br.ReadString('\n') if err != nil { return 0, nil, nil, fmt.Errorf("read status: %w", err) } status, err := parseStatus(statusLine) if err != nil { return 0, nil, nil, err } headers := map[string]string{} for { line, err := br.ReadString('\n') if err != nil { return 0, nil, nil, fmt.Errorf("read header: %w", err) } trimmed := bytes.TrimRight(line, "\r\n") if len(trimmed) == 0 { break } col := bytes.IndexByte(trimmed, ':') if col < 0 { continue } k := string(bytes.ToLower(bytes.TrimSpace(trimmed[:col]))) v := string(bytes.TrimSpace(trimmed[col+1:])) headers[k] = v } // Redirects don't have meaningful bodies; return early. if status >= 300 && status < 400 { return status, headers, nil, nil } var body []byte te := headers["transfer-encoding"] if te != "" && bytes.Contains(bytes.ToLower([]byte(te)), "chunked") { body, err = readChunked(br, maxBytes) if err != nil { return 0, nil, nil, err } } else if cl := headers["content-length"]; cl != "" { n, err := strconv.ParseInt(cl, 10, 64) if err != nil || n < 0 { return 0, nil, nil, fmt.Errorf("bad content-length") } if n > maxBytes { return 0, nil, nil, fmt.Errorf("response too large: %d", n) } body = []byte{:n} if _, err := io.ReadFull(br, body); err != nil { return 0, nil, nil, fmt.Errorf("read body: %w", err) } } else { body, err = readToEOF(br, maxBytes) if err != nil { return 0, nil, nil, err } } return status, headers, body, nil } func parseStatus(line string) (int, error) { trimmed := bytes.TrimSpace(line) sp1 := bytes.IndexByte(trimmed, ' ') if sp1 < 0 { return 0, fmt.Errorf("bad status line: %s", trimmed) } rest := bytes.TrimSpace(trimmed[sp1+1:]) sp2 := bytes.IndexByte(rest, ' ') var statusBytes []byte if sp2 < 0 { statusBytes = rest } else { statusBytes = rest[:sp2] } n, err := strconv.Atoi(string(statusBytes)) if err != nil { return 0, fmt.Errorf("bad status code: %s", statusBytes) } return n, nil } func readToEOF(r io.Reader, max int64) ([]byte, error) { var buf []byte chunk := []byte{:32 * 1024} for { n, err := r.Read(chunk) if n > 0 { if int64(len(buf))+int64(n) > max { return nil, fmt.Errorf("response too large") } buf = append(buf, chunk[:n]...) } if err == io.EOF { return buf, nil } if err != nil { return nil, err } } } func readChunked(br *bufio.Reader, max int64) ([]byte, error) { var buf []byte for { line, err := br.ReadString('\n') if err != nil { return nil, fmt.Errorf("chunked size: %w", err) } sz := bytes.TrimRight(line, "\r\n") if sc := bytes.IndexByte(sz, ';'); sc >= 0 { sz = sz[:sc] } n, err := strconv.ParseInt(string(bytes.TrimSpace(sz)), 16, 64) if err != nil { return nil, fmt.Errorf("chunked size parse: %w", err) } if n == 0 { // Discard trailers up to empty line. for { t, err := br.ReadString('\n') if err != nil { break } if len(bytes.TrimRight(t, "\r\n")) == 0 { break } } return buf, nil } if int64(len(buf))+n > max { return nil, fmt.Errorf("response too large") } chunk := []byte{:n} if _, err := io.ReadFull(br, chunk); err != nil { return nil, fmt.Errorf("chunked body: %w", err) } buf = append(buf, chunk...) if _, err := br.ReadString('\n'); err != nil { return nil, fmt.Errorf("chunked trailer: %w", err) } } } func resolveRedirect(base, loc string) (string, error) { if len(loc) == 0 { return "", fmt.Errorf("empty location") } // Absolute URL with scheme. if len(loc) >= 7 && (string(loc[:7]) == "http://" || (len(loc) >= 8 && string(loc[:8]) == "https://")) { return loc, nil } bu, err := url.Parse(base) if err != nil { return "", err } // Network-path reference: "//host/path" if len(loc) >= 2 && loc[0] == '/' && loc[1] == '/' { return bu.Scheme | ":" | loc, nil } // Absolute path: "/path" if loc[0] == '/' { return bu.Scheme | "://" | bu.Host | loc, nil } // Relative path — resolve against base path. basePath := bu.Path if basePath == "" { basePath = "/" } slash := -1 for i := len(basePath) - 1; i >= 0; i-- { if basePath[i] == '/' { slash = i break } } if slash < 0 { basePath = "/" } else { basePath = basePath[:slash+1] } return bu.Scheme | "://" | bu.Host | basePath | loc, nil }