package auth

import (
	"context"
	"crypto/rsa"
	"errors"
	"fmt"
	"net/http"
	"net/url"

	"SimpleTutorialHosting/internal/models"

	"github.com/crewjam/saml/samlsp"
)

// SAMLAuthenticator authenticates HTTP requests through a SAML identity
// provider using crewjam/saml's samlsp middleware, and maps the IdP-supplied
// "groups" attribute onto application roles.
type SAMLAuthenticator struct {
	middleware  *samlsp.Middleware
	adminGroup  string // group granting models.RoleAdmin; empty grants nothing
	viewerGroup string // group granting models.RoleViewer; empty grants nothing
}

// NewSAMLAuthenticator builds a SAMLAuthenticator from cfg.
//
// cfg.RootURL must be an absolute URL (scheme and host); samlsp derives the
// service-provider endpoints it advertises from it. cfg.IDPMetadata must be
// parseable IdP metadata. cfg.KeyPair must hold an RSA private key and a
// parsed leaf certificate, used to sign requests to the IdP. IdP-initiated
// logins are accepted.
//
// It returns an error if cfg is nil or any of the above is invalid.
func NewSAMLAuthenticator(cfg *models.SAMLConfig) (*SAMLAuthenticator, error) {
	if cfg == nil {
		return nil, errors.New("saml config is nil")
	}

	rootURL, err := url.Parse(cfg.RootURL)
	if err != nil {
		return nil, fmt.Errorf("invalid root url: %w", err)
	}
	// url.Parse happily accepts relative references; the IdP needs absolute
	// ACS/metadata URLs to post back to, so reject those up front.
	if rootURL.Scheme == "" || rootURL.Host == "" {
		return nil, fmt.Errorf("invalid root url %q: must be absolute", cfg.RootURL)
	}

	idpMetadata, err := samlsp.ParseMetadata(cfg.IDPMetadata)
	if err != nil {
		return nil, fmt.Errorf("invalid idp metadata: %w", err)
	}

	// Checked assertion: a non-RSA key (e.g. ECDSA) must produce an error
	// instead of panicking at startup.
	key, ok := cfg.KeyPair.PrivateKey.(*rsa.PrivateKey)
	if !ok {
		return nil, fmt.Errorf("saml key pair: private key must be *rsa.PrivateKey, got %T", cfg.KeyPair.PrivateKey)
	}
	// SignRequest is enabled below, and signing needs the certificate; a
	// missing (unparsed) leaf is rejected here rather than at first login.
	if cfg.KeyPair.Leaf == nil {
		return nil, errors.New("saml key pair: leaf certificate is not parsed")
	}

	middleware, err := samlsp.New(samlsp.Options{
		URL:               *rootURL,
		Key:               key,
		Certificate:       cfg.KeyPair.Leaf,
		IDPMetadata:       idpMetadata,
		SignRequest:       true,
		AllowIDPInitiated: true,
	})
	if err != nil {
		return nil, fmt.Errorf("creating saml middleware: %w", err)
	}

	return &SAMLAuthenticator{
		middleware:  middleware,
		adminGroup:  cfg.AdminGroup,
		viewerGroup: cfg.ViewerGroup,
	}, nil
}

// GetMiddleware returns a handler serving the SAML protocol routes
// /saml/metadata (service-provider metadata), /saml/acs (assertion consumer
// service) and /saml/sso.
func (a *SAMLAuthenticator) GetMiddleware() http.Handler {
	mux := http.NewServeMux()
	mux.HandleFunc("/saml/metadata", a.middleware.ServeMetadata)
	mux.HandleFunc("/saml/acs", a.middleware.ServeACS)
	// NOTE(review): samlsp.Middleware.ServeHTTP appears to dispatch only the
	// metadata and ACS paths and answer anything else with 404, so this route
	// may not start a login. If an SP-initiated entry point is intended,
	// samlsp's HandleStartAuthFlow is the likely hook — confirm against the
	// pinned samlsp version.
	mux.HandleFunc("/saml/sso", a.middleware.ServeHTTP)
	return mux
}

// Authenticate wraps next with samlsp.RequireAccount, so next only runs for
// requests carrying a valid SAML session. The resolved *models.User is stored
// in the request context under UserContextKey. Requests whose session cannot
// be mapped to a user (see GetUser) are answered with 401 Unauthorized.
func (a *SAMLAuthenticator) Authenticate(next http.Handler) http.Handler {
	return a.middleware.RequireAccount(
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			user, err := a.GetUser(r)
			if err != nil {
				http.Error(w, "Unauthorized", http.StatusUnauthorized)
				return
			}
			ctx := context.WithValue(r.Context(), UserContextKey, user)
			next.ServeHTTP(w, r.WithContext(ctx))
		}),
	)
}

// GetUser builds the *models.User for the SAML session attached to r's
// context.
//
// The role comes from the multi-valued "groups" attribute. Membership in the
// configured admin group yields models.RoleAdmin, which takes precedence.
// Otherwise, membership in the viewer group yields models.RoleViewer. The
// "email" attribute becomes the user ID and "displayName" the username; both
// are required.
//
// It returns an error when there is no session, the session exposes no
// attributes, no role matches, or a required attribute is missing.
func (a *SAMLAuthenticator) GetUser(r *http.Request) (*models.User, error) {
	session := samlsp.SessionFromContext(r.Context())
	if session == nil {
		return nil, errors.New("no session in context")
	}
	withAttrs, ok := session.(samlsp.SessionWithAttributes)
	if !ok {
		return nil, errors.New("session has no attributes")
	}
	attrs := withAttrs.GetAttributes()

	// Look at every group value. AttributeFromContext returns only the first
	// one, which would reject users whose first listed group is unrelated.
	groups := attrs["groups"]
	var role models.Role
	switch {
	case samlGroupsContain(groups, a.adminGroup):
		role = models.RoleAdmin
	case samlGroupsContain(groups, a.viewerGroup):
		role = models.RoleViewer
	default:
		return nil, errors.New("user has no valid role")
	}

	email := attrs.Get("email")
	if email == "" {
		return nil, errors.New("no email attribute")
	}
	displayName := attrs.Get("displayName")
	if displayName == "" {
		return nil, errors.New("no displayName attribute")
	}

	return &models.User{
		ID:       email,
		Username: displayName,
		Role:     role,
	}, nil
}

// samlGroupsContain reports whether want is one of groups. An empty want never
// matches, so an unconfigured group name cannot grant a role to a user whose
// assertion lacks a groups attribute or carries an empty one.
func samlGroupsContain(groups []string, want string) bool {
	if want == "" {
		return false
	}
	for _, g := range groups {
		if g == want {
			return true
		}
	}
	return false
}