internal.go raw

   1  /*
   2   *
   3   * Copyright 2024 gRPC authors.
   4   *
   5   * Licensed under the Apache License, Version 2.0 (the "License");
   6   * you may not use this file except in compliance with the License.
   7   * You may obtain a copy of the License at
   8   *
   9   *     http://www.apache.org/licenses/LICENSE-2.0
  10   *
  11   * Unless required by applicable law or agreed to in writing, software
  12   * distributed under the License is distributed on an "AS IS" BASIS,
  13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14   * See the License for the specific language governing permissions and
  15   * limitations under the License.
  16   *
  17   */
  18  
// Package internal contains code that is shared by both the reflection package
// and the test package. The packages are split this way in order to avoid a
// dependency on the deprecated package github.com/golang/protobuf.
  22  package internal
  23  
  24  import (
  25  	"io"
  26  	"sort"
  27  
  28  	"google.golang.org/grpc"
  29  	"google.golang.org/grpc/codes"
  30  	"google.golang.org/grpc/status"
  31  	"google.golang.org/protobuf/proto"
  32  	"google.golang.org/protobuf/reflect/protodesc"
  33  	"google.golang.org/protobuf/reflect/protoreflect"
  34  	"google.golang.org/protobuf/reflect/protoregistry"
  35  
  36  	v1reflectiongrpc "google.golang.org/grpc/reflection/grpc_reflection_v1"
  37  	v1reflectionpb "google.golang.org/grpc/reflection/grpc_reflection_v1"
  38  	v1alphareflectiongrpc "google.golang.org/grpc/reflection/grpc_reflection_v1alpha"
  39  	v1alphareflectionpb "google.golang.org/grpc/reflection/grpc_reflection_v1alpha"
  40  )
  41  
  42  // ServiceInfoProvider is an interface used to retrieve metadata about the
  43  // services to expose.
  44  type ServiceInfoProvider interface {
  45  	GetServiceInfo() map[string]grpc.ServiceInfo
  46  }
  47  
  48  // ExtensionResolver is the interface used to query details about extensions.
  49  // This interface is satisfied by protoregistry.GlobalTypes.
  50  type ExtensionResolver interface {
  51  	protoregistry.ExtensionTypeResolver
  52  	RangeExtensionsByMessage(message protoreflect.FullName, f func(protoreflect.ExtensionType) bool)
  53  }
  54  
  55  // ServerReflectionServer is the server API for ServerReflection service.
  56  type ServerReflectionServer struct {
  57  	v1alphareflectiongrpc.UnimplementedServerReflectionServer
  58  	S            ServiceInfoProvider
  59  	DescResolver protodesc.Resolver
  60  	ExtResolver  ExtensionResolver
  61  }
  62  
  63  // FileDescWithDependencies returns a slice of serialized fileDescriptors in
  64  // wire format ([]byte). The fileDescriptors will include fd and all the
  65  // transitive dependencies of fd with names not in sentFileDescriptors.
  66  func (s *ServerReflectionServer) FileDescWithDependencies(fd protoreflect.FileDescriptor, sentFileDescriptors map[string]bool) ([][]byte, error) {
  67  	if fd.IsPlaceholder() {
  68  		// If the given root file is a placeholder, treat it
  69  		// as missing instead of serializing it.
  70  		return nil, protoregistry.NotFound
  71  	}
  72  	var r [][]byte
  73  	queue := []protoreflect.FileDescriptor{fd}
  74  	for len(queue) > 0 {
  75  		currentfd := queue[0]
  76  		queue = queue[1:]
  77  		if currentfd.IsPlaceholder() {
  78  			// Skip any missing files in the dependency graph.
  79  			continue
  80  		}
  81  		if sent := sentFileDescriptors[currentfd.Path()]; len(r) == 0 || !sent {
  82  			sentFileDescriptors[currentfd.Path()] = true
  83  			fdProto := protodesc.ToFileDescriptorProto(currentfd)
  84  			currentfdEncoded, err := proto.Marshal(fdProto)
  85  			if err != nil {
  86  				return nil, err
  87  			}
  88  			r = append(r, currentfdEncoded)
  89  		}
  90  		for i := 0; i < currentfd.Imports().Len(); i++ {
  91  			queue = append(queue, currentfd.Imports().Get(i))
  92  		}
  93  	}
  94  	return r, nil
  95  }
  96  
  97  // FileDescEncodingContainingSymbol finds the file descriptor containing the
  98  // given symbol, finds all of its previously unsent transitive dependencies,
  99  // does marshalling on them, and returns the marshalled result. The given symbol
 100  // can be a type, a service or a method.
 101  func (s *ServerReflectionServer) FileDescEncodingContainingSymbol(name string, sentFileDescriptors map[string]bool) ([][]byte, error) {
 102  	d, err := s.DescResolver.FindDescriptorByName(protoreflect.FullName(name))
 103  	if err != nil {
 104  		return nil, err
 105  	}
 106  	return s.FileDescWithDependencies(d.ParentFile(), sentFileDescriptors)
 107  }
 108  
 109  // FileDescEncodingContainingExtension finds the file descriptor containing
 110  // given extension, finds all of its previously unsent transitive dependencies,
 111  // does marshalling on them, and returns the marshalled result.
 112  func (s *ServerReflectionServer) FileDescEncodingContainingExtension(typeName string, extNum int32, sentFileDescriptors map[string]bool) ([][]byte, error) {
 113  	xt, err := s.ExtResolver.FindExtensionByNumber(protoreflect.FullName(typeName), protoreflect.FieldNumber(extNum))
 114  	if err != nil {
 115  		return nil, err
 116  	}
 117  	return s.FileDescWithDependencies(xt.TypeDescriptor().ParentFile(), sentFileDescriptors)
 118  }
 119  
 120  // AllExtensionNumbersForTypeName returns all extension numbers for the given type.
 121  func (s *ServerReflectionServer) AllExtensionNumbersForTypeName(name string) ([]int32, error) {
 122  	var numbers []int32
 123  	s.ExtResolver.RangeExtensionsByMessage(protoreflect.FullName(name), func(xt protoreflect.ExtensionType) bool {
 124  		numbers = append(numbers, int32(xt.TypeDescriptor().Number()))
 125  		return true
 126  	})
 127  	sort.Slice(numbers, func(i, j int) bool {
 128  		return numbers[i] < numbers[j]
 129  	})
 130  	if len(numbers) == 0 {
 131  		// maybe return an error if given type name is not known
 132  		if _, err := s.DescResolver.FindDescriptorByName(protoreflect.FullName(name)); err != nil {
 133  			return nil, err
 134  		}
 135  	}
 136  	return numbers, nil
 137  }
 138  
 139  // ListServices returns the names of services this server exposes.
 140  func (s *ServerReflectionServer) ListServices() []*v1reflectionpb.ServiceResponse {
 141  	serviceInfo := s.S.GetServiceInfo()
 142  	resp := make([]*v1reflectionpb.ServiceResponse, 0, len(serviceInfo))
 143  	for svc := range serviceInfo {
 144  		resp = append(resp, &v1reflectionpb.ServiceResponse{Name: svc})
 145  	}
 146  	sort.Slice(resp, func(i, j int) bool {
 147  		return resp[i].Name < resp[j].Name
 148  	})
 149  	return resp
 150  }
 151  
 152  // ServerReflectionInfo is the reflection service handler.
 153  func (s *ServerReflectionServer) ServerReflectionInfo(stream v1reflectiongrpc.ServerReflection_ServerReflectionInfoServer) error {
 154  	sentFileDescriptors := make(map[string]bool)
 155  	for {
 156  		in, err := stream.Recv()
 157  		if err == io.EOF {
 158  			return nil
 159  		}
 160  		if err != nil {
 161  			return err
 162  		}
 163  
 164  		out := &v1reflectionpb.ServerReflectionResponse{
 165  			ValidHost:       in.Host,
 166  			OriginalRequest: in,
 167  		}
 168  		switch req := in.MessageRequest.(type) {
 169  		case *v1reflectionpb.ServerReflectionRequest_FileByFilename:
 170  			var b [][]byte
 171  			fd, err := s.DescResolver.FindFileByPath(req.FileByFilename)
 172  			if err == nil {
 173  				b, err = s.FileDescWithDependencies(fd, sentFileDescriptors)
 174  			}
 175  			if err != nil {
 176  				out.MessageResponse = &v1reflectionpb.ServerReflectionResponse_ErrorResponse{
 177  					ErrorResponse: &v1reflectionpb.ErrorResponse{
 178  						ErrorCode:    int32(codes.NotFound),
 179  						ErrorMessage: err.Error(),
 180  					},
 181  				}
 182  			} else {
 183  				out.MessageResponse = &v1reflectionpb.ServerReflectionResponse_FileDescriptorResponse{
 184  					FileDescriptorResponse: &v1reflectionpb.FileDescriptorResponse{FileDescriptorProto: b},
 185  				}
 186  			}
 187  		case *v1reflectionpb.ServerReflectionRequest_FileContainingSymbol:
 188  			b, err := s.FileDescEncodingContainingSymbol(req.FileContainingSymbol, sentFileDescriptors)
 189  			if err != nil {
 190  				out.MessageResponse = &v1reflectionpb.ServerReflectionResponse_ErrorResponse{
 191  					ErrorResponse: &v1reflectionpb.ErrorResponse{
 192  						ErrorCode:    int32(codes.NotFound),
 193  						ErrorMessage: err.Error(),
 194  					},
 195  				}
 196  			} else {
 197  				out.MessageResponse = &v1reflectionpb.ServerReflectionResponse_FileDescriptorResponse{
 198  					FileDescriptorResponse: &v1reflectionpb.FileDescriptorResponse{FileDescriptorProto: b},
 199  				}
 200  			}
 201  		case *v1reflectionpb.ServerReflectionRequest_FileContainingExtension:
 202  			typeName := req.FileContainingExtension.ContainingType
 203  			extNum := req.FileContainingExtension.ExtensionNumber
 204  			b, err := s.FileDescEncodingContainingExtension(typeName, extNum, sentFileDescriptors)
 205  			if err != nil {
 206  				out.MessageResponse = &v1reflectionpb.ServerReflectionResponse_ErrorResponse{
 207  					ErrorResponse: &v1reflectionpb.ErrorResponse{
 208  						ErrorCode:    int32(codes.NotFound),
 209  						ErrorMessage: err.Error(),
 210  					},
 211  				}
 212  			} else {
 213  				out.MessageResponse = &v1reflectionpb.ServerReflectionResponse_FileDescriptorResponse{
 214  					FileDescriptorResponse: &v1reflectionpb.FileDescriptorResponse{FileDescriptorProto: b},
 215  				}
 216  			}
 217  		case *v1reflectionpb.ServerReflectionRequest_AllExtensionNumbersOfType:
 218  			extNums, err := s.AllExtensionNumbersForTypeName(req.AllExtensionNumbersOfType)
 219  			if err != nil {
 220  				out.MessageResponse = &v1reflectionpb.ServerReflectionResponse_ErrorResponse{
 221  					ErrorResponse: &v1reflectionpb.ErrorResponse{
 222  						ErrorCode:    int32(codes.NotFound),
 223  						ErrorMessage: err.Error(),
 224  					},
 225  				}
 226  			} else {
 227  				out.MessageResponse = &v1reflectionpb.ServerReflectionResponse_AllExtensionNumbersResponse{
 228  					AllExtensionNumbersResponse: &v1reflectionpb.ExtensionNumberResponse{
 229  						BaseTypeName:    req.AllExtensionNumbersOfType,
 230  						ExtensionNumber: extNums,
 231  					},
 232  				}
 233  			}
 234  		case *v1reflectionpb.ServerReflectionRequest_ListServices:
 235  			out.MessageResponse = &v1reflectionpb.ServerReflectionResponse_ListServicesResponse{
 236  				ListServicesResponse: &v1reflectionpb.ListServiceResponse{
 237  					Service: s.ListServices(),
 238  				},
 239  			}
 240  		default:
 241  			return status.Errorf(codes.InvalidArgument, "invalid MessageRequest: %v", in.MessageRequest)
 242  		}
 243  
 244  		if err := stream.Send(out); err != nil {
 245  			return err
 246  		}
 247  	}
 248  }
 249  
 250  // V1ToV1AlphaResponse converts a v1 ServerReflectionResponse to a v1alpha.
 251  func V1ToV1AlphaResponse(v1 *v1reflectionpb.ServerReflectionResponse) *v1alphareflectionpb.ServerReflectionResponse {
 252  	var v1alpha v1alphareflectionpb.ServerReflectionResponse
 253  	v1alpha.ValidHost = v1.ValidHost
 254  	if v1.OriginalRequest != nil {
 255  		v1alpha.OriginalRequest = V1ToV1AlphaRequest(v1.OriginalRequest)
 256  	}
 257  	switch mr := v1.MessageResponse.(type) {
 258  	case *v1reflectionpb.ServerReflectionResponse_FileDescriptorResponse:
 259  		if mr != nil {
 260  			v1alpha.MessageResponse = &v1alphareflectionpb.ServerReflectionResponse_FileDescriptorResponse{
 261  				FileDescriptorResponse: &v1alphareflectionpb.FileDescriptorResponse{
 262  					FileDescriptorProto: mr.FileDescriptorResponse.GetFileDescriptorProto(),
 263  				},
 264  			}
 265  		}
 266  	case *v1reflectionpb.ServerReflectionResponse_AllExtensionNumbersResponse:
 267  		if mr != nil {
 268  			v1alpha.MessageResponse = &v1alphareflectionpb.ServerReflectionResponse_AllExtensionNumbersResponse{
 269  				AllExtensionNumbersResponse: &v1alphareflectionpb.ExtensionNumberResponse{
 270  					BaseTypeName:    mr.AllExtensionNumbersResponse.GetBaseTypeName(),
 271  					ExtensionNumber: mr.AllExtensionNumbersResponse.GetExtensionNumber(),
 272  				},
 273  			}
 274  		}
 275  	case *v1reflectionpb.ServerReflectionResponse_ListServicesResponse:
 276  		if mr != nil {
 277  			svcs := make([]*v1alphareflectionpb.ServiceResponse, len(mr.ListServicesResponse.GetService()))
 278  			for i, svc := range mr.ListServicesResponse.GetService() {
 279  				svcs[i] = &v1alphareflectionpb.ServiceResponse{
 280  					Name: svc.GetName(),
 281  				}
 282  			}
 283  			v1alpha.MessageResponse = &v1alphareflectionpb.ServerReflectionResponse_ListServicesResponse{
 284  				ListServicesResponse: &v1alphareflectionpb.ListServiceResponse{
 285  					Service: svcs,
 286  				},
 287  			}
 288  		}
 289  	case *v1reflectionpb.ServerReflectionResponse_ErrorResponse:
 290  		if mr != nil {
 291  			v1alpha.MessageResponse = &v1alphareflectionpb.ServerReflectionResponse_ErrorResponse{
 292  				ErrorResponse: &v1alphareflectionpb.ErrorResponse{
 293  					ErrorCode:    mr.ErrorResponse.GetErrorCode(),
 294  					ErrorMessage: mr.ErrorResponse.GetErrorMessage(),
 295  				},
 296  			}
 297  		}
 298  	default:
 299  		// no value set
 300  	}
 301  	return &v1alpha
 302  }
 303  
 304  // V1AlphaToV1Request converts a v1alpha ServerReflectionRequest to a v1.
 305  func V1AlphaToV1Request(v1alpha *v1alphareflectionpb.ServerReflectionRequest) *v1reflectionpb.ServerReflectionRequest {
 306  	var v1 v1reflectionpb.ServerReflectionRequest
 307  	v1.Host = v1alpha.Host
 308  	switch mr := v1alpha.MessageRequest.(type) {
 309  	case *v1alphareflectionpb.ServerReflectionRequest_FileByFilename:
 310  		v1.MessageRequest = &v1reflectionpb.ServerReflectionRequest_FileByFilename{
 311  			FileByFilename: mr.FileByFilename,
 312  		}
 313  	case *v1alphareflectionpb.ServerReflectionRequest_FileContainingSymbol:
 314  		v1.MessageRequest = &v1reflectionpb.ServerReflectionRequest_FileContainingSymbol{
 315  			FileContainingSymbol: mr.FileContainingSymbol,
 316  		}
 317  	case *v1alphareflectionpb.ServerReflectionRequest_FileContainingExtension:
 318  		if mr.FileContainingExtension != nil {
 319  			v1.MessageRequest = &v1reflectionpb.ServerReflectionRequest_FileContainingExtension{
 320  				FileContainingExtension: &v1reflectionpb.ExtensionRequest{
 321  					ContainingType:  mr.FileContainingExtension.GetContainingType(),
 322  					ExtensionNumber: mr.FileContainingExtension.GetExtensionNumber(),
 323  				},
 324  			}
 325  		}
 326  	case *v1alphareflectionpb.ServerReflectionRequest_AllExtensionNumbersOfType:
 327  		v1.MessageRequest = &v1reflectionpb.ServerReflectionRequest_AllExtensionNumbersOfType{
 328  			AllExtensionNumbersOfType: mr.AllExtensionNumbersOfType,
 329  		}
 330  	case *v1alphareflectionpb.ServerReflectionRequest_ListServices:
 331  		v1.MessageRequest = &v1reflectionpb.ServerReflectionRequest_ListServices{
 332  			ListServices: mr.ListServices,
 333  		}
 334  	default:
 335  		// no value set
 336  	}
 337  	return &v1
 338  }
 339  
 340  // V1ToV1AlphaRequest converts a v1 ServerReflectionRequest to a v1alpha.
 341  func V1ToV1AlphaRequest(v1 *v1reflectionpb.ServerReflectionRequest) *v1alphareflectionpb.ServerReflectionRequest {
 342  	var v1alpha v1alphareflectionpb.ServerReflectionRequest
 343  	v1alpha.Host = v1.Host
 344  	switch mr := v1.MessageRequest.(type) {
 345  	case *v1reflectionpb.ServerReflectionRequest_FileByFilename:
 346  		if mr != nil {
 347  			v1alpha.MessageRequest = &v1alphareflectionpb.ServerReflectionRequest_FileByFilename{
 348  				FileByFilename: mr.FileByFilename,
 349  			}
 350  		}
 351  	case *v1reflectionpb.ServerReflectionRequest_FileContainingSymbol:
 352  		if mr != nil {
 353  			v1alpha.MessageRequest = &v1alphareflectionpb.ServerReflectionRequest_FileContainingSymbol{
 354  				FileContainingSymbol: mr.FileContainingSymbol,
 355  			}
 356  		}
 357  	case *v1reflectionpb.ServerReflectionRequest_FileContainingExtension:
 358  		if mr != nil {
 359  			v1alpha.MessageRequest = &v1alphareflectionpb.ServerReflectionRequest_FileContainingExtension{
 360  				FileContainingExtension: &v1alphareflectionpb.ExtensionRequest{
 361  					ContainingType:  mr.FileContainingExtension.GetContainingType(),
 362  					ExtensionNumber: mr.FileContainingExtension.GetExtensionNumber(),
 363  				},
 364  			}
 365  		}
 366  	case *v1reflectionpb.ServerReflectionRequest_AllExtensionNumbersOfType:
 367  		if mr != nil {
 368  			v1alpha.MessageRequest = &v1alphareflectionpb.ServerReflectionRequest_AllExtensionNumbersOfType{
 369  				AllExtensionNumbersOfType: mr.AllExtensionNumbersOfType,
 370  			}
 371  		}
 372  	case *v1reflectionpb.ServerReflectionRequest_ListServices:
 373  		if mr != nil {
 374  			v1alpha.MessageRequest = &v1alphareflectionpb.ServerReflectionRequest_ListServices{
 375  				ListServices: mr.ListServices,
 376  			}
 377  		}
 378  	default:
 379  		// no value set
 380  	}
 381  	return &v1alpha
 382  }
 383  
 384  // V1AlphaToV1Response converts a v1alpha ServerReflectionResponse to a v1.
 385  func V1AlphaToV1Response(v1alpha *v1alphareflectionpb.ServerReflectionResponse) *v1reflectionpb.ServerReflectionResponse {
 386  	var v1 v1reflectionpb.ServerReflectionResponse
 387  	v1.ValidHost = v1alpha.ValidHost
 388  	if v1alpha.OriginalRequest != nil {
 389  		v1.OriginalRequest = V1AlphaToV1Request(v1alpha.OriginalRequest)
 390  	}
 391  	switch mr := v1alpha.MessageResponse.(type) {
 392  	case *v1alphareflectionpb.ServerReflectionResponse_FileDescriptorResponse:
 393  		if mr != nil {
 394  			v1.MessageResponse = &v1reflectionpb.ServerReflectionResponse_FileDescriptorResponse{
 395  				FileDescriptorResponse: &v1reflectionpb.FileDescriptorResponse{
 396  					FileDescriptorProto: mr.FileDescriptorResponse.GetFileDescriptorProto(),
 397  				},
 398  			}
 399  		}
 400  	case *v1alphareflectionpb.ServerReflectionResponse_AllExtensionNumbersResponse:
 401  		if mr != nil {
 402  			v1.MessageResponse = &v1reflectionpb.ServerReflectionResponse_AllExtensionNumbersResponse{
 403  				AllExtensionNumbersResponse: &v1reflectionpb.ExtensionNumberResponse{
 404  					BaseTypeName:    mr.AllExtensionNumbersResponse.GetBaseTypeName(),
 405  					ExtensionNumber: mr.AllExtensionNumbersResponse.GetExtensionNumber(),
 406  				},
 407  			}
 408  		}
 409  	case *v1alphareflectionpb.ServerReflectionResponse_ListServicesResponse:
 410  		if mr != nil {
 411  			svcs := make([]*v1reflectionpb.ServiceResponse, len(mr.ListServicesResponse.GetService()))
 412  			for i, svc := range mr.ListServicesResponse.GetService() {
 413  				svcs[i] = &v1reflectionpb.ServiceResponse{
 414  					Name: svc.GetName(),
 415  				}
 416  			}
 417  			v1.MessageResponse = &v1reflectionpb.ServerReflectionResponse_ListServicesResponse{
 418  				ListServicesResponse: &v1reflectionpb.ListServiceResponse{
 419  					Service: svcs,
 420  				},
 421  			}
 422  		}
 423  	case *v1alphareflectionpb.ServerReflectionResponse_ErrorResponse:
 424  		if mr != nil {
 425  			v1.MessageResponse = &v1reflectionpb.ServerReflectionResponse_ErrorResponse{
 426  				ErrorResponse: &v1reflectionpb.ErrorResponse{
 427  					ErrorCode:    mr.ErrorResponse.GetErrorCode(),
 428  					ErrorMessage: mr.ErrorResponse.GetErrorMessage(),
 429  				},
 430  			}
 431  		}
 432  	default:
 433  		// no value set
 434  	}
 435  	return &v1
 436  }
 437