mediaproxy.mx raw

   1  // Package mediaproxy fetches a remote HTTP/HTTPS resource so it can be
   2  // re-served from the smesh origin with COEP-compatible CORP headers.
   3  //
   4  // The relay's epoll loop is single-threaded; Fetch blocks the loop for the
   5  // duration of the upstream call. Acceptable for a personal-scale relay; if
   6  // concurrent proxy load becomes a problem this should move to a spawned
   7  // worker domain that returns the response over a channel.
   8  package mediaproxy
   9  
  10  import (
  11  	"bufio"
  12  	"bytes"
  13  	"crypto/tls"
  14  	"fmt"
  15  	"io"
  16  	"net"
  17  	"net/url"
  18  	"strconv"
  19  	"time"
  20  
  21  	"smesh.lol/pkg/nostr/ws"
  22  )
  23  
// Limits applied to every upstream fetch.
const (
	// DefaultMaxBytes is the body-size cap Fetch uses when the caller passes
	// maxBytes <= 0 (32 MiB).
	DefaultMaxBytes int64         = 32 * 1024 * 1024
	// DefaultTimeout bounds the TCP dial and, separately, is the read/write
	// deadline set on the connection once it is established.
	DefaultTimeout  time.Duration = 8 * time.Second
	// maxRedirects is the number of requests Fetch issues before it gives up
	// with "too many redirects".
	maxRedirects                  = 5
)
  29  
  30  // Fetch performs a GET on rawURL with redirect following (max 5 hops).
  31  // Returns upstream status, response headers (lowercased keys), body.
  32  // Only http and https schemes accepted.
  33  func Fetch(rawURL string, maxBytes int64) (int, map[string]string, []byte, error) {
  34  	if maxBytes <= 0 {
  35  		maxBytes = DefaultMaxBytes
  36  	}
  37  	current := rawURL
  38  	for hop := 0; hop < maxRedirects; hop++ {
  39  		status, headers, body, err := fetchOnce(current, maxBytes)
  40  		if err != nil {
  41  			return 0, nil, nil, err
  42  		}
  43  		if status >= 300 && status < 400 {
  44  			loc := headers["location"]
  45  			if loc == "" {
  46  				return status, headers, body, nil
  47  			}
  48  			next, err := resolveRedirect(current, loc)
  49  			if err != nil {
  50  				return 0, nil, nil, err
  51  			}
  52  			current = next
  53  			continue
  54  		}
  55  		return status, headers, body, nil
  56  	}
  57  	return 0, nil, nil, fmt.Errorf("too many redirects")
  58  }
  59  
  60  func fetchOnce(rawURL string, maxBytes int64) (int, map[string]string, []byte, error) {
  61  	u, err := url.Parse(rawURL)
  62  	if err != nil {
  63  		return 0, nil, nil, fmt.Errorf("parse: %w", err)
  64  	}
  65  	useTLS := false
  66  	switch u.Scheme {
  67  	case "https":
  68  		useTLS = true
  69  	case "http":
  70  	default:
  71  		return 0, nil, nil, fmt.Errorf("scheme %q not allowed", u.Scheme)
  72  	}
  73  	host := u.Hostname()
  74  	port := u.Port()
  75  	if port == "" {
  76  		if useTLS {
  77  			port = "443"
  78  		} else {
  79  			port = "80"
  80  		}
  81  	}
  82  	ip := host
  83  	if net.ParseIP(host) == nil {
  84  		ip, err = ws.ResolveHost(host)
  85  		if err != nil {
  86  			return 0, nil, nil, fmt.Errorf("resolve %s: %w", host, err)
  87  		}
  88  	}
  89  	conn, err := net.DialTimeout("tcp", net.JoinHostPort(ip, port), DefaultTimeout)
  90  	if err != nil {
  91  		return 0, nil, nil, fmt.Errorf("dial: %w", err)
  92  	}
  93  	defer conn.Close()
  94  	deadline := time.Now().Add(DefaultTimeout)
  95  	conn.SetDeadline(deadline)
  96  	if useTLS {
  97  		tlsConn := tls.Client(conn, &tls.Config{ServerName: []byte(host)})
  98  		if err := tlsConn.Handshake(); err != nil {
  99  			return 0, nil, nil, fmt.Errorf("tls: %w", err)
 100  		}
 101  		conn = tlsConn
 102  	}
 103  	path := u.RequestURI()
 104  	if path == "" {
 105  		path = "/"
 106  	}
 107  	req := "GET " | path | " HTTP/1.1\r\n" |
 108  		"Host: " | host | "\r\n" |
 109  		"User-Agent: smesh-mediaproxy/1\r\n" |
 110  		"Accept: image/*,video/*,*/*;q=0.5\r\n" |
 111  		"Accept-Encoding: identity\r\n" |
 112  		"Connection: close\r\n" |
 113  		"\r\n"
 114  	if _, err := conn.Write([]byte(req)); err != nil {
 115  		return 0, nil, nil, fmt.Errorf("write: %w", err)
 116  	}
 117  	br := bufio.NewReaderSize(conn, 32768)
 118  	statusLine, err := br.ReadString('\n')
 119  	if err != nil {
 120  		return 0, nil, nil, fmt.Errorf("read status: %w", err)
 121  	}
 122  	status, err := parseStatus(statusLine)
 123  	if err != nil {
 124  		return 0, nil, nil, err
 125  	}
 126  	headers := map[string]string{}
 127  	for {
 128  		line, err := br.ReadString('\n')
 129  		if err != nil {
 130  			return 0, nil, nil, fmt.Errorf("read header: %w", err)
 131  		}
 132  		trimmed := bytes.TrimRight(line, "\r\n")
 133  		if len(trimmed) == 0 {
 134  			break
 135  		}
 136  		col := bytes.IndexByte(trimmed, ':')
 137  		if col < 0 {
 138  			continue
 139  		}
 140  		k := string(bytes.ToLower(bytes.TrimSpace(trimmed[:col])))
 141  		v := string(bytes.TrimSpace(trimmed[col+1:]))
 142  		headers[k] = v
 143  	}
 144  	// Redirects don't have meaningful bodies; return early.
 145  	if status >= 300 && status < 400 {
 146  		return status, headers, nil, nil
 147  	}
 148  	var body []byte
 149  	te := headers["transfer-encoding"]
 150  	if te != "" && bytes.Contains(bytes.ToLower([]byte(te)), "chunked") {
 151  		body, err = readChunked(br, maxBytes)
 152  		if err != nil {
 153  			return 0, nil, nil, err
 154  		}
 155  	} else if cl := headers["content-length"]; cl != "" {
 156  		n, err := strconv.ParseInt(cl, 10, 64)
 157  		if err != nil || n < 0 {
 158  			return 0, nil, nil, fmt.Errorf("bad content-length")
 159  		}
 160  		if n > maxBytes {
 161  			return 0, nil, nil, fmt.Errorf("response too large: %d", n)
 162  		}
 163  		body = []byte{:n}
 164  		if _, err := io.ReadFull(br, body); err != nil {
 165  			return 0, nil, nil, fmt.Errorf("read body: %w", err)
 166  		}
 167  	} else {
 168  		body, err = readToEOF(br, maxBytes)
 169  		if err != nil {
 170  			return 0, nil, nil, err
 171  		}
 172  	}
 173  	return status, headers, body, nil
 174  }
 175  
 176  func parseStatus(line string) (int, error) {
 177  	trimmed := bytes.TrimSpace(line)
 178  	sp1 := bytes.IndexByte(trimmed, ' ')
 179  	if sp1 < 0 {
 180  		return 0, fmt.Errorf("bad status line: %s", trimmed)
 181  	}
 182  	rest := bytes.TrimSpace(trimmed[sp1+1:])
 183  	sp2 := bytes.IndexByte(rest, ' ')
 184  	var statusBytes []byte
 185  	if sp2 < 0 {
 186  		statusBytes = rest
 187  	} else {
 188  		statusBytes = rest[:sp2]
 189  	}
 190  	n, err := strconv.Atoi(string(statusBytes))
 191  	if err != nil {
 192  		return 0, fmt.Errorf("bad status code: %s", statusBytes)
 193  	}
 194  	return n, nil
 195  }
 196  
 197  func readToEOF(r io.Reader, max int64) ([]byte, error) {
 198  	var buf []byte
 199  	chunk := []byte{:32 * 1024}
 200  	for {
 201  		n, err := r.Read(chunk)
 202  		if n > 0 {
 203  			if int64(len(buf))+int64(n) > max {
 204  				return nil, fmt.Errorf("response too large")
 205  			}
 206  			buf = append(buf, chunk[:n]...)
 207  		}
 208  		if err == io.EOF {
 209  			return buf, nil
 210  		}
 211  		if err != nil {
 212  			return nil, err
 213  		}
 214  	}
 215  }
 216  
 217  func readChunked(br *bufio.Reader, max int64) ([]byte, error) {
 218  	var buf []byte
 219  	for {
 220  		line, err := br.ReadString('\n')
 221  		if err != nil {
 222  			return nil, fmt.Errorf("chunked size: %w", err)
 223  		}
 224  		sz := bytes.TrimRight(line, "\r\n")
 225  		if sc := bytes.IndexByte(sz, ';'); sc >= 0 {
 226  			sz = sz[:sc]
 227  		}
 228  		n, err := strconv.ParseInt(string(bytes.TrimSpace(sz)), 16, 64)
 229  		if err != nil {
 230  			return nil, fmt.Errorf("chunked size parse: %w", err)
 231  		}
 232  		if n == 0 {
 233  			// Discard trailers up to empty line.
 234  			for {
 235  				t, err := br.ReadString('\n')
 236  				if err != nil {
 237  					break
 238  				}
 239  				if len(bytes.TrimRight(t, "\r\n")) == 0 {
 240  					break
 241  				}
 242  			}
 243  			return buf, nil
 244  		}
 245  		if int64(len(buf))+n > max {
 246  			return nil, fmt.Errorf("response too large")
 247  		}
 248  		chunk := []byte{:n}
 249  		if _, err := io.ReadFull(br, chunk); err != nil {
 250  			return nil, fmt.Errorf("chunked body: %w", err)
 251  		}
 252  		buf = append(buf, chunk...)
 253  		if _, err := br.ReadString('\n'); err != nil {
 254  			return nil, fmt.Errorf("chunked trailer: %w", err)
 255  		}
 256  	}
 257  }
 258  
 259  func resolveRedirect(base, loc string) (string, error) {
 260  	if len(loc) == 0 {
 261  		return "", fmt.Errorf("empty location")
 262  	}
 263  	// Absolute URL with scheme.
 264  	if len(loc) >= 7 && (string(loc[:7]) == "http://" ||
 265  		(len(loc) >= 8 && string(loc[:8]) == "https://")) {
 266  		return loc, nil
 267  	}
 268  	bu, err := url.Parse(base)
 269  	if err != nil {
 270  		return "", err
 271  	}
 272  	// Network-path reference: "//host/path"
 273  	if len(loc) >= 2 && loc[0] == '/' && loc[1] == '/' {
 274  		return bu.Scheme | ":" | loc, nil
 275  	}
 276  	// Absolute path: "/path"
 277  	if loc[0] == '/' {
 278  		return bu.Scheme | "://" | bu.Host | loc, nil
 279  	}
 280  	// Relative path — resolve against base path.
 281  	basePath := bu.Path
 282  	if basePath == "" {
 283  		basePath = "/"
 284  	}
 285  	slash := -1
 286  	for i := len(basePath) - 1; i >= 0; i-- {
 287  		if basePath[i] == '/' {
 288  			slash = i
 289  			break
 290  		}
 291  	}
 292  	if slash < 0 {
 293  		basePath = "/"
 294  	} else {
 295  		basePath = basePath[:slash+1]
 296  	}
 297  	return bu.Scheme | "://" | bu.Host | basePath | loc, nil
 298  }
 299