Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion pkg/oatproxy/oatproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ type OATProxy struct {
clientMetadata *OAuthClientMetadata
defaultPDS string
public bool
httpClient *http.Client
}

type Config struct {
Expand All @@ -40,6 +41,7 @@ type Config struct {
ClientMetadata *OAuthClientMetadata
DefaultPDS string
Public bool
HTTPClient *http.Client
}

func New(conf *Config) *OATProxy {
Expand All @@ -62,6 +64,11 @@ func New(conf *Config) *OATProxy {
defaultPDS: conf.DefaultPDS,
public: conf.Public,
}
if conf.HTTPClient != nil {
o.httpClient = conf.HTTPClient
} else {
o.httpClient = http.DefaultClient
}
if conf.Lock != nil {
o.lock = conf.Lock
} else {
Expand All @@ -77,7 +84,7 @@ func New(conf *Config) *OATProxy {

o.Echo.GET("/.well-known/oauth-authorization-server", o.HandleOAuthAuthorizationServer)
o.Echo.GET("/.well-known/oauth-protected-resource", o.HandleOAuthProtectedResource)
o.Echo.GET("/xrpc/com.atproto.identity.resolveHandle", HandleComAtprotoIdentityResolveHandle)
o.Echo.GET("/xrpc/com.atproto.identity.resolveHandle", o.HandleComAtprotoIdentityResolveHandle)
o.Echo.POST("/oauth/par", o.HandleOAuthPAR)
o.Echo.GET("/oauth/authorize", o.HandleOAuthAuthorize)
o.Echo.GET("/oauth/return", o.HandleOAuthReturn)
Expand Down
2 changes: 1 addition & 1 deletion pkg/oatproxy/oauth_1_par.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func (o *OATProxy) NewPAR(ctx context.Context, c echo.Context, par *PAR, dpopHea
var expectedURL *url.URL
var upstreamAuthServerURL string
if o.public {
_, service, httpErr := ResolveHandleAndService(ctx, par.LoginHint)
_, service, httpErr := ResolveHandleAndServiceWithClient(ctx, par.LoginHint, o.httpClient)
if httpErr != nil {
return nil, httpErr
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/oatproxy/oauth_2_authorize.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,17 +111,17 @@ func (o *OATProxy) Authorize(ctx context.Context, requestURI, clientID string) (
service = session.Handle
} else {
var httpErr *echo.HTTPError
did, service, httpErr = ResolveHandleAndService(ctx, session.Handle)
did, service, httpErr = ResolveHandleAndServiceWithClient(ctx, session.Handle, o.httpClient)
if httpErr != nil {
return "", httpErr
}
did, err = ResolveHandle(ctx, session.Handle)
did, err = ResolveHandleWithClient(ctx, session.Handle, o.httpClient)
if err != nil {
return "", echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("failed to resolve handle '%s': %s", session.Handle, err))
}

var handle2 string
service, handle2, err = ResolveService(ctx, did)
service, handle2, err = ResolveServiceWithClient(ctx, did, o.httpClient)
if err != nil {
return "", echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("failed to resolve service for DID '%s': %s", did, err))
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/oatproxy/oauth_3_return.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ func (o *OATProxy) Return(ctx context.Context, code string, iss string, state st
session.DownstreamAuthorizationCode = downstreamCode
session.UpstreamScope = itResp.Scope
if session.DID == "" {
_, handle, err := ResolveService(ctx, itResp.Sub)
_, handle, err := ResolveServiceWithClient(ctx, itResp.Sub, o.httpClient)
if err != nil {
return "", echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("failed to resolve service for DID '%s': %s", itResp.Sub, err))
}
Expand Down
38 changes: 24 additions & 14 deletions pkg/oatproxy/resolution.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,12 @@ import (
"go.opentelemetry.io/otel"
)

// func ResolveHandle(ctx context.Context, handle string) (string, error) {
// return ResolveHandleWithClient(ctx, handle, http.DefaultClient)
// }

// mostly borrowed from github.com/streamplace/atproto-oauth-golang, MIT license
func ResolveHandle(ctx context.Context, handle string) (string, error) {
func ResolveHandleWithClient(ctx context.Context, handle string, client *http.Client) (string, error) {
var did string

_, err := syntax.ParseHandle(handle)
Expand All @@ -39,32 +43,34 @@ func ResolveHandle(ctx context.Context, handle string) (string, error) {
ctx,
"GET",
fmt.Sprintf("https://%s/.well-known/atproto-did", handle),
// "https://webhook.site/ce546544-f7ef-4880-9cbe-c2bf15ad9840",
nil,
)
req.Header.Del("accept-encoding")
if err != nil {
return "", err
return "", fmt.Errorf("unable to resolve handle: failed to create request: %s", err)
}

resp, err := http.DefaultClient.Do(req)
resp, err := client.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
io.Copy(io.Discard, resp.Body)
return "", fmt.Errorf("unable to resolve handle")
return "", fmt.Errorf("unable to resolve handle, got http status %d", resp.StatusCode)
}

b, err := io.ReadAll(resp.Body)
if err != nil {
return "", err
return "", fmt.Errorf("unable to resolve handle: failed to read response body: %s", err)
}

maybeDid := strings.TrimSpace(string(b))

if _, err := syntax.ParseDID(maybeDid); err != nil {
return "", fmt.Errorf("unable to resolve handle")
return "", fmt.Errorf("unable to resolve handle: failed to parse DID: %s", err)
}

did = maybeDid
Expand All @@ -73,7 +79,11 @@ func ResolveHandle(ctx context.Context, handle string) (string, error) {
return did, nil
}

func ResolveService(ctx context.Context, did string) (string, string, error) {
// func ResolveService(ctx context.Context, did string) (string, string, error) {
// return ResolveServiceWithClient(ctx, did, http.DefaultClient)
// }

func ResolveServiceWithClient(ctx context.Context, did string, client *http.Client) (string, string, error) {
type Identity struct {
AlsoKnownAs []string `json:"alsoKnownAs"`
Service []struct {
Expand Down Expand Up @@ -137,14 +147,14 @@ func ResolveService(ctx context.Context, did string) (string, string, error) {
}

// returns did, service
func ResolveHandleAndService(ctx context.Context, handle string) (string, string, *echo.HTTPError) {
did, err := ResolveHandle(ctx, handle)
func ResolveHandleAndServiceWithClient(ctx context.Context, handle string, client *http.Client) (string, string, *echo.HTTPError) {
did, err := ResolveHandleWithClient(ctx, handle, client)
if err != nil {
return "", "", echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("failed to resolve handle '%s': %s", handle, err))
}

var handle2 string
service, handle2, err := ResolveService(ctx, did)
service, handle2, err := ResolveServiceWithClient(ctx, did, client)
if err != nil {
return "", "", echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("failed to resolve service for DID '%s': %s", did, err))
}
Expand All @@ -154,22 +164,22 @@ func ResolveHandleAndService(ctx context.Context, handle string) (string, string
return did, service, nil
}

func HandleComAtprotoIdentityResolveHandle(c echo.Context) error {
func (o *OATProxy) HandleComAtprotoIdentityResolveHandle(c echo.Context) error {
ctx, span := otel.Tracer("server").Start(c.Request().Context(), "HandleComAtprotoIdentityResolveHandle")
defer span.End()
handle := c.QueryParam("handle")
var out *comatprototypes.IdentityResolveHandle_Output
var handleErr error
// func (s *Server) handleComAtprotoIdentityResolveHandle(ctx context.Context,handle string) (*comatprototypes.IdentityResolveHandle_Output, error)
out, handleErr = handleComAtprotoIdentityResolveHandle(ctx, handle)
out, handleErr = handleComAtprotoIdentityResolveHandle(ctx, handle, o.httpClient)
if handleErr != nil {
return handleErr
}
return c.JSON(200, out)
}

func handleComAtprotoIdentityResolveHandle(ctx context.Context, handle string) (*comatprototypes.IdentityResolveHandle_Output, error) {
did, err := ResolveHandle(ctx, handle)
func handleComAtprotoIdentityResolveHandle(ctx context.Context, handle string, client *http.Client) (*comatprototypes.IdentityResolveHandle_Output, error) {
did, err := ResolveHandleWithClient(ctx, handle, client)
if err != nil {
return nil, err
}
Expand Down