proxy_worker.mx raw

   1  package wire
   2  
   3  import (
   4  	"bytes"
   5  	"io"
   6  	"runtime"
   7  	"syscall"
   8  
   9  	"smesh.lol/pkg/mediaproxy"
  10  )
  11  
  12  // ProxyWorker is the spawn target for a media proxy domain.
  13  // Uses ChildPipeFd directly (bypassing the broken spawn channel binding)
  14  // to read Codec-encoded ProxyRequest frames and write ProxyResponse frames.
  15  // chanID=0 = requests (in direction), chanID=1 = responses (out direction).
  16  func ProxyWorker(in chan ProxyRequest, out chan ProxyResponse) {
  17  	fd := int32(runtime.ChildPipeFd)
  18  	if fd < 0 {
  19  		return
  20  	}
  21  	for {
  22  		var hdr [6]byte
  23  		if !proxyReadAll(fd, hdr[:]) {
  24  			return
  25  		}
  26  		chanID := uint16(hdr[0])<<8 | uint16(hdr[1])
  27  		l := uint32(hdr[2])<<24 | uint32(hdr[3])<<16 | uint32(hdr[4])<<8 | uint32(hdr[5])
  28  		if chanID != 0 || l == 0 || l > 8<<20 {
  29  			proxyDrain(fd, l)
  30  			continue
  31  		}
  32  		payload := []byte{:l}
  33  		if !proxyReadAll(fd, payload) {
  34  			return
  35  		}
  36  		var req ProxyRequest
  37  		if err := req.DecodeFrom(bytes.NewReader(payload)); err != nil {
  38  			continue
  39  		}
  40  		resp := ProxyResponse{ReqID: req.ReqID}
  41  		maxBytes := int64(req.MaxBytes)
  42  		if maxBytes <= 0 {
  43  			maxBytes = 32 * 1024 * 1024
  44  		}
  45  		status, upstream, body, err := mediaproxy.Fetch(string(req.URL), maxBytes)
  46  		if err != nil {
  47  			resp.Status = -1
  48  			resp.Err = []byte(err.Error())
  49  		} else {
  50  			resp.Status = int32(status)
  51  			if status >= 200 && status < 300 {
  52  				ct := upstream["content-type"]
  53  				if !proxyAllowedCT(ct) {
  54  					resp.Status = 415
  55  				} else {
  56  					resp.ContentType = []byte(ct)
  57  					resp.Body = body
  58  				}
  59  			}
  60  		}
  61  		var buf bytes.Buffer
  62  		resp.EncodeTo(&buf)
  63  		data := buf.Bytes()
  64  		lresp := uint32(len(data))
  65  		var rhdr [6]byte
  66  		rhdr[0] = 0; rhdr[1] = 1
  67  		rhdr[2] = byte(lresp >> 24); rhdr[3] = byte(lresp >> 16)
  68  		rhdr[4] = byte(lresp >> 8); rhdr[5] = byte(lresp)
  69  		frame := []byte{:0:6 + int(lresp)}
  70  		frame = append(frame, rhdr[:]...)
  71  		frame = append(frame, data...)
  72  		syscall.Write(int(fd), frame)
  73  	}
  74  }
  75  
  76  func proxyAllowedCT(ct string) bool {
  77  	if len(ct) >= 6 && (ct[:6] == "image/" || ct[:6] == "video/") {
  78  		return true
  79  	}
  80  	return ct == "application/octet-stream" || ct == "application/octet-stream; charset=utf-8"
  81  }
  82  
  83  func proxyReadAll(fd int32, buf []byte) bool {
  84  	got := 0
  85  	for got < len(buf) {
  86  		n, err := syscall.Read(int(fd), buf[got:])
  87  		if n <= 0 || err != nil {
  88  			return false
  89  		}
  90  		got += n
  91  	}
  92  	return true
  93  }
  94  
  95  func proxyDrain(fd int32, n uint32) {
  96  	discard := []byte{:512}
  97  	for n > 0 {
  98  		chunk := int(n)
  99  		if chunk > 512 {
 100  			chunk = 512
 101  		}
 102  		nr, err := syscall.Read(int(fd), discard[:chunk])
 103  		if err != nil || nr <= 0 {
 104  			return
 105  		}
 106  		n -= uint32(nr)
 107  	}
 108  }
 109  
 110  type proxyReader struct{ fd int32 }
 111  
 112  func (r *proxyReader) Read(b []byte) (int, error) {
 113  	n, err := syscall.Read(int(r.fd), b)
 114  	if n == 0 && err == nil {
 115  		return 0, io.EOF
 116  	}
 117  	return n, err
 118  }
 119