1 package websocket
2 3 import (
4 "context"
5 "fmt"
6 "io"
7 "math"
8 "net"
9 "sync/atomic"
10 "time"
11 )
12 13 // NetConn converts a *websocket.Conn into a net.Conn.
14 //
15 // It's for tunneling arbitrary protocols over WebSockets.
16 // Few users of the library will need this but it's tricky to implement
17 // correctly and so provided in the library.
18 // See https://github.com/nhooyr/websocket/issues/100.
19 //
20 // Every Write to the net.Conn will correspond to a message write of
21 // the given type on *websocket.Conn.
22 //
23 // The passed ctx bounds the lifetime of the net.Conn. If cancelled,
24 // all reads and writes on the net.Conn will be cancelled.
25 //
26 // If a message is read that is not of the correct type, the connection
27 // will be closed with StatusUnsupportedData and an error will be returned.
28 //
29 // Close will close the *websocket.Conn with StatusNormalClosure.
30 //
31 // When a deadline is hit and there is an active read or write goroutine, the
32 // connection will be closed. This is different from most net.Conn implementations
33 // where only the reading/writing goroutines are interrupted but the connection
34 // is kept alive.
35 //
36 // The Addr methods will return the real addresses for connections obtained
37 // from websocket.Accept. But for connections obtained from websocket.Dial, a mock net.Addr
38 // will be returned that gives "websocket" for Network() and "websocket/unknown-addr" for
39 // String(). This is because websocket.Dial only exposes a io.ReadWriteCloser instead of the
40 // full net.Conn to us.
41 //
42 // When running as WASM, the Addr methods will always return the mock address described above.
43 //
44 // A received StatusNormalClosure or StatusGoingAway close frame will be translated to
45 // io.EOF when reading.
46 //
47 // Furthermore, the ReadLimit is set to -1 to disable it.
48 func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn {
49 c.SetReadLimit(-1)
50 51 nc := &netConn{
52 c: c,
53 msgType: msgType,
54 readMu: newMu(c),
55 writeMu: newMu(c),
56 }
57 58 nc.writeCtx, nc.writeCancel = context.WithCancel(ctx)
59 nc.readCtx, nc.readCancel = context.WithCancel(ctx)
60 61 nc.writeTimer = time.AfterFunc(math.MaxInt64, func() {
62 if !nc.writeMu.tryLock() {
63 // If the lock cannot be acquired, then there is an
64 // active write goroutine and so we should cancel the context.
65 nc.writeCancel()
66 return
67 }
68 defer nc.writeMu.unlock()
69 70 // Prevents future writes from writing until the deadline is reset.
71 atomic.StoreInt64(&nc.writeExpired, 1)
72 })
73 if !nc.writeTimer.Stop() {
74 <-nc.writeTimer.C
75 }
76 77 nc.readTimer = time.AfterFunc(math.MaxInt64, func() {
78 if !nc.readMu.tryLock() {
79 // If the lock cannot be acquired, then there is an
80 // active read goroutine and so we should cancel the context.
81 nc.readCancel()
82 return
83 }
84 defer nc.readMu.unlock()
85 86 // Prevents future reads from reading until the deadline is reset.
87 atomic.StoreInt64(&nc.readExpired, 1)
88 })
89 if !nc.readTimer.Stop() {
90 <-nc.readTimer.C
91 }
92 93 return nc
94 }
95 96 type netConn struct {
97 // These must be first to be aligned on 32 bit platforms.
98 // https://github.com/nhooyr/websocket/pull/438
99 readExpired int64
100 writeExpired int64
101 102 c *Conn
103 msgType MessageType
104 105 writeTimer *time.Timer
106 writeMu *mu
107 writeCtx context.Context
108 writeCancel context.CancelFunc
109 110 readTimer *time.Timer
111 readMu *mu
112 readCtx context.Context
113 readCancel context.CancelFunc
114 readEOFed bool
115 reader io.Reader
116 }
117 118 var _ net.Conn = &netConn{}
119 120 func (nc *netConn) Close() error {
121 nc.writeTimer.Stop()
122 nc.writeCancel()
123 nc.readTimer.Stop()
124 nc.readCancel()
125 return nc.c.Close(StatusNormalClosure, "")
126 }
127 128 func (nc *netConn) Write(p []byte) (int, error) {
129 nc.writeMu.forceLock()
130 defer nc.writeMu.unlock()
131 132 if atomic.LoadInt64(&nc.writeExpired) == 1 {
133 return 0, fmt.Errorf("failed to write: %w", context.DeadlineExceeded)
134 }
135 136 err := nc.c.Write(nc.writeCtx, nc.msgType, p)
137 if err != nil {
138 return 0, err
139 }
140 return len(p), nil
141 }
142 143 func (nc *netConn) Read(p []byte) (int, error) {
144 nc.readMu.forceLock()
145 defer nc.readMu.unlock()
146 147 for {
148 n, err := nc.read(p)
149 if err != nil {
150 return n, err
151 }
152 if n == 0 {
153 continue
154 }
155 return n, nil
156 }
157 }
158 159 func (nc *netConn) read(p []byte) (int, error) {
160 if atomic.LoadInt64(&nc.readExpired) == 1 {
161 return 0, fmt.Errorf("failed to read: %w", context.DeadlineExceeded)
162 }
163 164 if nc.readEOFed {
165 return 0, io.EOF
166 }
167 168 if nc.reader == nil {
169 typ, r, err := nc.c.Reader(nc.readCtx)
170 if err != nil {
171 switch CloseStatus(err) {
172 case StatusNormalClosure, StatusGoingAway:
173 nc.readEOFed = true
174 return 0, io.EOF
175 }
176 return 0, err
177 }
178 if typ != nc.msgType {
179 err := fmt.Errorf("unexpected frame type read (expected %v): %v", nc.msgType, typ)
180 nc.c.Close(StatusUnsupportedData, err.Error())
181 return 0, err
182 }
183 nc.reader = r
184 }
185 186 n, err := nc.reader.Read(p)
187 if err == io.EOF {
188 nc.reader = nil
189 err = nil
190 }
191 return n, err
192 }
// websocketAddr is the placeholder net.Addr reported when the real network
// address of the connection is not available.
type websocketAddr struct{}
196 197 func (a websocketAddr) Network() string {
198 return "websocket"
199 }
200 201 func (a websocketAddr) String() string {
202 return "websocket/unknown-addr"
203 }
204 205 func (nc *netConn) SetDeadline(t time.Time) error {
206 nc.SetWriteDeadline(t)
207 nc.SetReadDeadline(t)
208 return nil
209 }
210 211 func (nc *netConn) SetWriteDeadline(t time.Time) error {
212 atomic.StoreInt64(&nc.writeExpired, 0)
213 if t.IsZero() {
214 nc.writeTimer.Stop()
215 } else {
216 dur := time.Until(t)
217 if dur <= 0 {
218 dur = 1
219 }
220 nc.writeTimer.Reset(dur)
221 }
222 return nil
223 }
224 225 func (nc *netConn) SetReadDeadline(t time.Time) error {
226 atomic.StoreInt64(&nc.readExpired, 0)
227 if t.IsZero() {
228 nc.readTimer.Stop()
229 } else {
230 dur := time.Until(t)
231 if dur <= 0 {
232 dur = 1
233 }
234 nc.readTimer.Reset(dur)
235 }
236 return nil
237 }
238