spirvcross.go raw

   1  // SPDX-License-Identifier: Unlicense OR MIT
   2  
   3  package main
   4  
   5  import (
   6  	"encoding/json"
   7  	"fmt"
   8  	"os/exec"
   9  	"path/filepath"
  10  	"sort"
  11  	"strings"
  12  
  13  	"github.com/p9c/p9/pkg/gel/gio/gpu/internal/driver"
  14  )
  15  
// Metadata contains reflection data about a shader, as decoded from the
// JSON printed by `spirv-cross --reflect` (see parseMetadata).
type Metadata struct {
	// Uniforms describes the shader's uniform blocks and the flattened
	// locations of their members. parseMetadata lays the blocks out back to
	// back and stores their combined size in Uniforms.Size.
	Uniforms driver.UniformsReflection
	// Inputs lists the shader's input variables, sorted by location.
	Inputs   []driver.InputLocation
	// Textures lists the texture names and bindings reported by spirv-cross.
	Textures []driver.TextureBinding
}
  22  
  23  // SPIRVCross cross-compiles spirv shaders to es, hlsl and others.
  24  type SPIRVCross struct {
  25  	Bin     string
  26  	WorkDir WorkDir
  27  }
  28  
  29  func NewSPIRVCross() *SPIRVCross { return &SPIRVCross{Bin: "spirv-cross"} }
  30  
  31  // Convert converts compute shader from spirv format to a target format.
  32  func (spirv *SPIRVCross) Convert(path, variant string, shader []byte, target, version string) (string, error) {
  33  	base := spirv.WorkDir.Path(filepath.Base(path), variant)
  34  
  35  	if err := spirv.WorkDir.WriteFile(base, shader); err != nil {
  36  		return "", fmt.Errorf("unable to write shader to disk: %w", err)
  37  	}
  38  
  39  	var cmd *exec.Cmd
  40  	switch target {
  41  	case "glsl":
  42  		cmd = exec.Command(spirv.Bin,
  43  			"--no-es",
  44  			"--version", version,
  45  		)
  46  	case "es":
  47  		cmd = exec.Command(spirv.Bin,
  48  			"--es",
  49  			"--version", version,
  50  		)
  51  	case "hlsl":
  52  		cmd = exec.Command(spirv.Bin,
  53  			"--hlsl",
  54  			"--shader-model", version,
  55  		)
  56  	default:
  57  		return "", fmt.Errorf("unknown target %q", target)
  58  	}
  59  	cmd.Args = append(cmd.Args, "--no-420pack-extension", base)
  60  
  61  	out, err := cmd.CombinedOutput()
  62  	if err != nil {
  63  		return "", fmt.Errorf("%s\nfailed to run %v: %w", out, cmd.Args, err)
  64  	}
  65  	s := string(out)
  66  	if target != "hlsl" {
  67  		// Strip Windows \r in line endings.
  68  		s = unixLineEnding(s)
  69  	}
  70  
  71  	return s, nil
  72  }
  73  
  74  // Metadata extracts metadata for a SPIR-V shader.
  75  func (spirv *SPIRVCross) Metadata(path, variant string, shader []byte) (Metadata, error) {
  76  	base := spirv.WorkDir.Path(filepath.Base(path), variant)
  77  
  78  	if err := spirv.WorkDir.WriteFile(base, shader); err != nil {
  79  		return Metadata{}, fmt.Errorf("unable to write shader to disk: %w", err)
  80  	}
  81  
  82  	cmd := exec.Command(spirv.Bin,
  83  		base,
  84  		"--reflect",
  85  	)
  86  
  87  	out, err := cmd.Output()
  88  	if err != nil {
  89  		return Metadata{}, fmt.Errorf("failed to run %v: %w", cmd.Args, err)
  90  	}
  91  
  92  	meta, err := parseMetadata(out)
  93  	if err != nil {
  94  		return Metadata{}, fmt.Errorf("%s\nfailed to parse metadata: %w", out, err)
  95  	}
  96  
  97  	return meta, nil
  98  }
  99  
 100  func parseMetadata(data []byte) (Metadata, error) {
 101  	var reflect struct {
 102  		Types map[string]struct {
 103  			Name    string `json:"name"`
 104  			Members []struct {
 105  				Name   string `json:"name"`
 106  				Type   string `json:"type"`
 107  				Offset int    `json:"offset"`
 108  			} `json:"members"`
 109  		} `json:"types"`
 110  		Inputs []struct {
 111  			Name     string `json:"name"`
 112  			Type     string `json:"type"`
 113  			Location int    `json:"location"`
 114  		} `json:"inputs"`
 115  		Textures []struct {
 116  			Name    string `json:"name"`
 117  			Type    string `json:"type"`
 118  			Set     int    `json:"set"`
 119  			Binding int    `json:"binding"`
 120  		} `json:"textures"`
 121  		UBOs []struct {
 122  			Name      string `json:"name"`
 123  			Type      string `json:"type"`
 124  			BlockSize int    `json:"block_size"`
 125  			Set       int    `json:"set"`
 126  			Binding   int    `json:"binding"`
 127  		} `json:"ubos"`
 128  	}
 129  	if err := json.Unmarshal(data, &reflect); err != nil {
 130  		return Metadata{}, fmt.Errorf("failed to parse reflection data: %w", err)
 131  	}
 132  
 133  	var m Metadata
 134  
 135  	for _, input := range reflect.Inputs {
 136  		dataType, dataSize, err := parseDataType(input.Type)
 137  		if err != nil {
 138  			return Metadata{}, fmt.Errorf("parseReflection: %v", err)
 139  		}
 140  		m.Inputs = append(m.Inputs, driver.InputLocation{
 141  			Name:          input.Name,
 142  			Location:      input.Location,
 143  			Semantic:      "TEXCOORD",
 144  			SemanticIndex: input.Location,
 145  			Type:          dataType,
 146  			Size:          dataSize,
 147  		})
 148  	}
 149  
 150  	sort.Slice(m.Inputs, func(i, j int) bool {
 151  		return m.Inputs[i].Location < m.Inputs[j].Location
 152  	})
 153  
 154  	blockOffset := 0
 155  	for _, block := range reflect.UBOs {
 156  		m.Uniforms.Blocks = append(m.Uniforms.Blocks, driver.UniformBlock{
 157  			Name:    block.Name,
 158  			Binding: block.Binding,
 159  		})
 160  		t := reflect.Types[block.Type]
 161  		// By convention uniform block variables are named by prepending an underscore
 162  		// and converting to lowercase.
 163  		blockVar := "_" + strings.ToLower(block.Name)
 164  		for _, member := range t.Members {
 165  			dataType, size, err := parseDataType(member.Type)
 166  			if err != nil {
 167  				return Metadata{}, fmt.Errorf("failed to parse reflection data: %v", err)
 168  			}
 169  			m.Uniforms.Locations = append(m.Uniforms.Locations, driver.UniformLocation{
 170  				Name:   fmt.Sprintf("%s.%s", blockVar, member.Name),
 171  				Type:   dataType,
 172  				Size:   size,
 173  				Offset: blockOffset + member.Offset,
 174  			})
 175  		}
 176  		blockOffset += block.BlockSize
 177  	}
 178  	m.Uniforms.Size = blockOffset
 179  
 180  	for _, texture := range reflect.Textures {
 181  		m.Textures = append(m.Textures, driver.TextureBinding{
 182  			Name:    texture.Name,
 183  			Binding: texture.Binding,
 184  		})
 185  	}
 186  
 187  	//return m, fmt.Errorf("not yet!: %+v", reflect)
 188  	return m, nil
 189  }
 190  
 191  func parseDataType(t string) (driver.DataType, int, error) {
 192  	switch t {
 193  	case "float":
 194  		return driver.DataTypeFloat, 1, nil
 195  	case "vec2":
 196  		return driver.DataTypeFloat, 2, nil
 197  	case "vec3":
 198  		return driver.DataTypeFloat, 3, nil
 199  	case "vec4":
 200  		return driver.DataTypeFloat, 4, nil
 201  	case "int":
 202  		return driver.DataTypeInt, 1, nil
 203  	case "int2":
 204  		return driver.DataTypeInt, 2, nil
 205  	case "int3":
 206  		return driver.DataTypeInt, 3, nil
 207  	case "int4":
 208  		return driver.DataTypeInt, 4, nil
 209  	default:
 210  		return 0, 0, fmt.Errorf("unsupported input data type: %s", t)
 211  	}
 212  }
 213