channel_membership.go raw

   1  package app
   2  
   3  import (
   4  	"context"
   5  	"encoding/json"
   6  	"sync"
   7  	"time"
   8  
   9  	"next.orly.dev/pkg/database"
  10  	"next.orly.dev/pkg/lol/chk"
  11  	"next.orly.dev/pkg/lol/log"
  12  	"next.orly.dev/pkg/nostr/encoders/event"
  13  	"next.orly.dev/pkg/nostr/encoders/filter"
  14  	hexenc "next.orly.dev/pkg/nostr/encoders/hex"
  15  	"next.orly.dev/pkg/nostr/encoders/kind"
  16  	"next.orly.dev/pkg/nostr/encoders/tag"
  17  )
  18  
  19  // channelAccessInfo holds parsed channel access control data from kind 41 metadata.
  20  type channelAccessInfo struct {
  21  	creator    string // hex pubkey of channel creator
  22  	accessMode string // "open", "whitelist", "blacklist"
  23  	mods       map[string]bool
  24  	members    map[string]bool // whitelisted members
  25  	blocked    map[string]bool // blacklisted users
  26  	invited    map[string]bool // pending invites (have access)
  27  	rejected   map[string]bool // rejected requests (no access)
  28  	cachedAt   time.Time
  29  }
  30  
  31  const channelCacheTTL = 30 * time.Second
  32  
  33  // channelRefCacheEntry caches whether an event ID references a channel event.
  34  type channelRefCacheEntry struct {
  35  	channelIDHex string
  36  	isChannel    bool
  37  	cachedAt     time.Time
  38  }
  39  
  40  // ChannelMembership manages channel access control lookups with caching.
  41  type ChannelMembership struct {
  42  	db       database.Database
  43  	cache    sync.Map // map[string]*channelAccessInfo (channel ID hex → info)
  44  	refCache sync.Map // map[string]*channelRefCacheEntry (event ID hex → channel ref info)
  45  }
  46  
  47  // NewChannelMembership creates a new channel membership checker.
  48  func NewChannelMembership(db database.Database) *ChannelMembership {
  49  	return &ChannelMembership{db: db}
  50  }
  51  
  52  // InvalidateChannel removes a channel's cached access info, forcing a re-fetch
  53  // on the next check. Call this when a new kind 41 event is ingested.
  54  func (cm *ChannelMembership) InvalidateChannel(channelIDHex string) {
  55  	cm.cache.Delete(channelIDHex)
  56  }
  57  
  58  // IsChannelMember checks whether the given pubkey (binary) is allowed to access
  59  // channel events of the given kind. Returns true if access is granted.
  60  //
  61  // Access rules:
  62  //   - Kinds 40, 41 (create, metadata): always allowed for any authenticated user (discovery)
  63  //   - Kinds 42-44 (message, hide, mute): depends on channel access mode
  64  //   - "open": all authenticated users allowed
  65  //   - "whitelist": only creator, mods, members, and invited users
  66  //   - "blacklist": everyone except blocked and rejected users
  67  func (cm *ChannelMembership) IsChannelMember(
  68  	ev *event.E,
  69  	userPubkey []byte,
  70  	ctx context.Context,
  71  ) bool {
  72  	if len(userPubkey) == 0 {
  73  		return false
  74  	}
  75  
  76  	// Kinds 40 and 41 are always readable for discovery
  77  	if kind.IsDiscoverableChannelKind(ev.Kind) {
  78  		return true
  79  	}
  80  
  81  	// For kinds 42-44, extract channel ID from #e tag
  82  	channelIDHex := extractChannelID(ev)
  83  	if channelIDHex == "" {
  84  		// No channel reference — allow (might be malformed, let other checks handle)
  85  		return true
  86  	}
  87  
  88  	userHex := hexenc.Enc(userPubkey)
  89  
  90  	info, err := cm.getChannelInfo(ctx, channelIDHex)
  91  	if err != nil || info == nil {
  92  		// If we can't determine channel info, allow access (fail open for now)
  93  		log.D.F("channel membership check: no info for channel %s, allowing", channelIDHex)
  94  		return true
  95  	}
  96  
  97  	// Creator always has access
  98  	if info.creator == userHex {
  99  		return true
 100  	}
 101  
 102  	// Mods always have access
 103  	if info.mods[userHex] {
 104  		return true
 105  	}
 106  
 107  	switch info.accessMode {
 108  	case "whitelist":
 109  		return info.members[userHex] || info.invited[userHex]
 110  	case "blacklist":
 111  		return !info.blocked[userHex] && !info.rejected[userHex]
 112  	default: // "open"
 113  		return true
 114  	}
 115  }
 116  
 117  // IsChannelMemberByID checks membership using a channel ID directly (not from an event).
 118  // Used by the publisher when delivering events.
 119  func (cm *ChannelMembership) IsChannelMemberByID(
 120  	channelIDHex string,
 121  	eventKind uint16,
 122  	userPubkey []byte,
 123  	ctx context.Context,
 124  ) bool {
 125  	if len(userPubkey) == 0 {
 126  		return false
 127  	}
 128  
 129  	if kind.IsDiscoverableChannelKind(eventKind) {
 130  		return true
 131  	}
 132  
 133  	if channelIDHex == "" {
 134  		return true
 135  	}
 136  
 137  	userHex := hexenc.Enc(userPubkey)
 138  
 139  	info, err := cm.getChannelInfo(ctx, channelIDHex)
 140  	if err != nil || info == nil {
 141  		return true
 142  	}
 143  
 144  	if info.creator == userHex {
 145  		return true
 146  	}
 147  
 148  	if info.mods[userHex] {
 149  		return true
 150  	}
 151  
 152  	switch info.accessMode {
 153  	case "whitelist":
 154  		return info.members[userHex] || info.invited[userHex]
 155  	case "blacklist":
 156  		return !info.blocked[userHex] && !info.rejected[userHex]
 157  	default:
 158  		return true
 159  	}
 160  }
 161  
 162  // ReferencesChannelEvent checks whether any e-tag in the event references a
 163  // restricted channel event (kind 42-44). If so, returns the channel ID and true.
 164  // Used to enforce channel membership for non-channel kinds (reactions, reposts,
 165  // reports, zaps, deletions) that reference channel events.
 166  func (cm *ChannelMembership) ReferencesChannelEvent(
 167  	ev *event.E,
 168  	ctx context.Context,
 169  ) (channelIDHex string, isChannel bool) {
 170  	if ev.Tags == nil {
 171  		return "", false
 172  	}
 173  	eTags := ev.Tags.GetAll([]byte("e"))
 174  	if len(eTags) == 0 {
 175  		return "", false
 176  	}
 177  
 178  	for _, et := range eTags {
 179  		if et.Len() < 2 {
 180  			continue
 181  		}
 182  		refIDHex := string(et.ValueHex())
 183  		if refIDHex == "" {
 184  			continue
 185  		}
 186  
 187  		// Check reference cache first
 188  		if cached, ok := cm.refCache.Load(refIDHex); ok {
 189  			entry := cached.(*channelRefCacheEntry)
 190  			if time.Since(entry.cachedAt) < channelCacheTTL {
 191  				if entry.isChannel {
 192  					return entry.channelIDHex, true
 193  				}
 194  				continue
 195  			}
 196  		}
 197  
 198  		// Look up the referenced event in the database
 199  		refIDBytes, err := hexenc.Dec(refIDHex)
 200  		if err != nil {
 201  			continue
 202  		}
 203  		ser, err := cm.db.GetSerialById(refIDBytes)
 204  		if err != nil || ser == nil {
 205  			// Cache negative result
 206  			cm.refCache.Store(refIDHex, &channelRefCacheEntry{
 207  				cachedAt: time.Now(),
 208  			})
 209  			continue
 210  		}
 211  		refEv, err := cm.db.FetchEventBySerial(ser)
 212  		if err != nil || refEv == nil {
 213  			cm.refCache.Store(refIDHex, &channelRefCacheEntry{
 214  				cachedAt: time.Now(),
 215  			})
 216  			continue
 217  		}
 218  
 219  		if kind.IsChannelKind(refEv.Kind) && !kind.IsDiscoverableChannelKind(refEv.Kind) {
 220  			// It's a restricted channel event (42-44). Extract the channel ID.
 221  			chID := extractChannelID(refEv)
 222  			cm.refCache.Store(refIDHex, &channelRefCacheEntry{
 223  				channelIDHex: chID,
 224  				isChannel:    true,
 225  				cachedAt:     time.Now(),
 226  			})
 227  			return chID, true
 228  		}
 229  
 230  		// Not a channel event — cache that too
 231  		cm.refCache.Store(refIDHex, &channelRefCacheEntry{
 232  			cachedAt: time.Now(),
 233  		})
 234  	}
 235  	return "", false
 236  }
 237  
 238  // getChannelInfo fetches (from cache or DB) the access control info for a channel.
 239  func (cm *ChannelMembership) getChannelInfo(
 240  	ctx context.Context,
 241  	channelIDHex string,
 242  ) (*channelAccessInfo, error) {
 243  	// Check cache
 244  	if cached, ok := cm.cache.Load(channelIDHex); ok {
 245  		info := cached.(*channelAccessInfo)
 246  		if time.Since(info.cachedAt) < channelCacheTTL {
 247  			return info, nil
 248  		}
 249  		// Expired, fall through to re-fetch
 250  	}
 251  
 252  	// Query for latest kind 41 metadata event for this channel
 253  	f := filter.New()
 254  	f.Kinds = kind.NewS(kind.ChannelMetadata)
 255  
 256  	// Build #e tag filter for the channel ID
 257  	channelIDBytes, err := hexenc.Dec(channelIDHex)
 258  	if err != nil {
 259  		return nil, err
 260  	}
 261  	eTag := tag.NewFromBytesSlice([]byte("e"), channelIDBytes)
 262  	f.Tags = tag.NewSWithCap(1)
 263  	*f.Tags = append(*f.Tags, eTag)
 264  
 265  	limit := uint(1)
 266  	f.Limit = &limit
 267  
 268  	queryCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
 269  	defer cancel()
 270  
 271  	events, err := cm.db.QueryEvents(queryCtx, f)
 272  	if chk.E(err) {
 273  		return nil, err
 274  	}
 275  
 276  	var info *channelAccessInfo
 277  
 278  	if len(events) > 0 {
 279  		info = parseChannelMetadata(events[0])
 280  	} else {
 281  		// No kind 41 found. Try kind 40 (channel creation) to get creator.
 282  		f2 := filter.New()
 283  		f2.Ids = tag.NewFromBytesSlice(channelIDBytes)
 284  		f2.Kinds = kind.NewS(kind.ChannelCreation)
 285  		limit2 := uint(1)
 286  		f2.Limit = &limit2
 287  
 288  		events2, err2 := cm.db.QueryEvents(queryCtx, f2)
 289  		if chk.E(err2) || len(events2) == 0 {
 290  			return nil, err2
 291  		}
 292  
 293  		// Default to open if no metadata exists
 294  		info = &channelAccessInfo{
 295  			creator:    hexenc.Enc(events2[0].Pubkey),
 296  			accessMode: "open",
 297  			mods:       make(map[string]bool),
 298  			members:    make(map[string]bool),
 299  			blocked:    make(map[string]bool),
 300  			invited:    make(map[string]bool),
 301  			rejected:   make(map[string]bool),
 302  		}
 303  	}
 304  
 305  	info.cachedAt = time.Now()
 306  	cm.cache.Store(channelIDHex, info)
 307  	return info, nil
 308  }
 309  
 310  // parseChannelMetadata extracts access control info from a kind 41 event.
 311  func parseChannelMetadata(ev *event.E) *channelAccessInfo {
 312  	info := &channelAccessInfo{
 313  		creator:    hexenc.Enc(ev.Pubkey),
 314  		accessMode: "open",
 315  		mods:       make(map[string]bool),
 316  		members:    make(map[string]bool),
 317  		blocked:    make(map[string]bool),
 318  		invited:    make(map[string]bool),
 319  		rejected:   make(map[string]bool),
 320  	}
 321  
 322  	// Parse content JSON for access_mode
 323  	if len(ev.Content) > 0 {
 324  		var content struct {
 325  			AccessMode string `json:"access_mode"`
 326  			InviteOnly bool   `json:"invite_only"` // backward compat
 327  		}
 328  		if err := json.Unmarshal(ev.Content, &content); err == nil {
 329  			if content.AccessMode != "" {
 330  				info.accessMode = content.AccessMode
 331  			} else if content.InviteOnly {
 332  				info.accessMode = "whitelist"
 333  			}
 334  		}
 335  	}
 336  
 337  	// Parse p-tags for roles
 338  	pTags := ev.Tags.GetAll([]byte("p"))
 339  	for _, pt := range pTags {
 340  		if pt.Len() < 3 {
 341  			continue
 342  		}
 343  		pkHex := string(pt.ValueHex())
 344  		role := string(pt.T[2])
 345  
 346  		switch role {
 347  		case "mod":
 348  			info.mods[pkHex] = true
 349  		case "member":
 350  			info.members[pkHex] = true
 351  		case "blocked":
 352  			info.blocked[pkHex] = true
 353  		case "invited":
 354  			info.invited[pkHex] = true
 355  		case "requested":
 356  			// Requested users don't have access
 357  		case "rejected":
 358  			info.rejected[pkHex] = true
 359  		}
 360  	}
 361  
 362  	return info
 363  }
 364  
 365  // extractChannelID gets the channel ID (hex) from an event's #e tag.
 366  // For kinds 42-44, the channel reference is in the first #e tag.
 367  func extractChannelID(ev *event.E) string {
 368  	if ev.Tags == nil {
 369  		return ""
 370  	}
 371  	eTags := ev.Tags.GetAll([]byte("e"))
 372  	for _, et := range eTags {
 373  		if et.Len() >= 2 {
 374  			val := et.ValueHex()
 375  			if len(val) > 0 {
 376  				return string(val)
 377  			}
 378  		}
 379  	}
 380  	return ""
 381  }
 382  
 383  // ExtractChannelIDFromEvent is the exported version of extractChannelID
 384  // for use by the publisher.
 385  func ExtractChannelIDFromEvent(ev *event.E) string {
 386  	return extractChannelID(ev)
 387  }
 388  
 389  // IsChannelEvent returns true if the event is a channel kind (40-44).
 390  // Convenience wrapper around kind.IsChannelKind.
 391  func IsChannelEvent(ev *event.E) bool {
 392  	return kind.IsChannelKind(ev.Kind)
 393  }
 394  
 395