register.mx raw

   1  // Copyright 2010 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  package zip
   6  
   7  import (
   8  	"compress/flate"
   9  	"errors"
  10  	"io"
  11  	"sync"
  12  )
  13  
  14  // A Compressor returns a new compressing writer, writing to w.
  15  // The WriteCloser's Close method must be used to flush pending data to w.
  16  // The Compressor itself must be safe to invoke from multiple goroutines
  17  // simultaneously, but each returned writer will be used only by
  18  // one goroutine at a time.
  19  type Compressor func(w io.Writer) (io.WriteCloser, error)
  20  
  21  // A Decompressor returns a new decompressing reader, reading from r.
  22  // The [io.ReadCloser]'s Close method must be used to release associated resources.
  23  // The Decompressor itself must be safe to invoke from multiple goroutines
  24  // simultaneously, but each returned reader will be used only by
  25  // one goroutine at a time.
  26  type Decompressor func(r io.Reader) io.ReadCloser
  27  
  28  var flateWriterPool sync.Pool
  29  
  30  func newFlateWriter(w io.Writer) io.WriteCloser {
  31  	fw, ok := flateWriterPool.Get().(*flate.Writer)
  32  	if ok {
  33  		fw.Reset(w)
  34  	} else {
  35  		fw, _ = flate.NewWriter(w, 5)
  36  	}
  37  	return &pooledFlateWriter{fw: fw}
  38  }
  39  
  40  type pooledFlateWriter struct {
  41  	mu sync.Mutex // guards Close and Write
  42  	fw *flate.Writer
  43  }
  44  
  45  func (w *pooledFlateWriter) Write(p []byte) (n int, err error) {
  46  	w.mu.Lock()
  47  	defer w.mu.Unlock()
  48  	if w.fw == nil {
  49  		return 0, errors.New("Write after Close")
  50  	}
  51  	return w.fw.Write(p)
  52  }
  53  
  54  func (w *pooledFlateWriter) Close() error {
  55  	w.mu.Lock()
  56  	defer w.mu.Unlock()
  57  	var err error
  58  	if w.fw != nil {
  59  		err = w.fw.Close()
  60  		flateWriterPool.Put(w.fw)
  61  		w.fw = nil
  62  	}
  63  	return err
  64  }
  65  
  66  var flateReaderPool sync.Pool
  67  
  68  func newFlateReader(r io.Reader) io.ReadCloser {
  69  	fr, ok := flateReaderPool.Get().(io.ReadCloser)
  70  	if ok {
  71  		fr.(flate.Resetter).Reset(r, nil)
  72  	} else {
  73  		fr = flate.NewReader(r)
  74  	}
  75  	return &pooledFlateReader{fr: fr}
  76  }
  77  
  78  type pooledFlateReader struct {
  79  	mu sync.Mutex // guards Close and Read
  80  	fr io.ReadCloser
  81  }
  82  
  83  func (r *pooledFlateReader) Read(p []byte) (n int, err error) {
  84  	r.mu.Lock()
  85  	defer r.mu.Unlock()
  86  	if r.fr == nil {
  87  		return 0, errors.New("Read after Close")
  88  	}
  89  	return r.fr.Read(p)
  90  }
  91  
  92  func (r *pooledFlateReader) Close() error {
  93  	r.mu.Lock()
  94  	defer r.mu.Unlock()
  95  	var err error
  96  	if r.fr != nil {
  97  		err = r.fr.Close()
  98  		flateReaderPool.Put(r.fr)
  99  		r.fr = nil
 100  	}
 101  	return err
 102  }
 103  
 104  var (
 105  	compressors   sync.Map // map[uint16]Compressor
 106  	decompressors sync.Map // map[uint16]Decompressor
 107  )
 108  
 109  func init() {
 110  	compressors.Store(Store, Compressor(func(w io.Writer) (io.WriteCloser, error) { return &nopCloser{w}, nil }))
 111  	compressors.Store(Deflate, Compressor(func(w io.Writer) (io.WriteCloser, error) { return newFlateWriter(w), nil }))
 112  
 113  	decompressors.Store(Store, Decompressor(io.NopCloser))
 114  	decompressors.Store(Deflate, Decompressor(newFlateReader))
 115  }
 116  
 117  // RegisterDecompressor allows custom decompressors for a specified method ID.
 118  // The common methods [Store] and [Deflate] are built in.
 119  func RegisterDecompressor(method uint16, dcomp Decompressor) {
 120  	if _, dup := decompressors.LoadOrStore(method, dcomp); dup {
 121  		panic("decompressor already registered")
 122  	}
 123  }
 124  
 125  // RegisterCompressor registers custom compressors for a specified method ID.
 126  // The common methods [Store] and [Deflate] are built in.
 127  func RegisterCompressor(method uint16, comp Compressor) {
 128  	if _, dup := compressors.LoadOrStore(method, comp); dup {
 129  		panic("compressor already registered")
 130  	}
 131  }
 132  
 133  func compressor(method uint16) Compressor {
 134  	ci, ok := compressors.Load(method)
 135  	if !ok {
 136  		return nil
 137  	}
 138  	return ci.(Compressor)
 139  }
 140  
 141  func decompressor(method uint16) Decompressor {
 142  	di, ok := decompressors.Load(method)
 143  	if !ok {
 144  		return nil
 145  	}
 146  	return di.(Decompressor)
 147  }
 148