diff --git a/pkg/cmd/login/login.go b/pkg/cmd/login/login.go index ffc32150..afccca49 100644 --- a/pkg/cmd/login/login.go +++ b/pkg/cmd/login/login.go @@ -34,6 +34,7 @@ type LoginStore interface { GetOrganizations(options *store.GetOrganizationsOptions) ([]entity.Organization, error) GetActiveOrganizationOrDefault() (*entity.Organization, error) CreateOrganization(req store.CreateOrganizationRequest) (*entity.Organization, error) + SetDefaultOrganization(org *entity.Organization) error GetServerSockFile() string GetWorkspaces(organizationID string, options *store.GetWorkspacesOptions) ([]entity.Workspace, error) UpdateUser(userID string, updatedUser *entity.UpdateUser) (*entity.User, error) @@ -58,6 +59,7 @@ func NewCmdLogin(t *terminal.Terminal, loginStore LoginStore, auth Auth) *cobra. var skipBrowser bool var emailFlag string var authProviderFlag string + var orgFlag string cmd := &cobra.Command{ Annotations: map[string]string{"housekeeping": ""}, @@ -82,7 +84,7 @@ func NewCmdLogin(t *terminal.Terminal, loginStore LoginStore, auth Auth) *cobra. }, Args: cmderrors.TransformToValidationError(cobra.NoArgs), RunE: func(cmd *cobra.Command, args []string) error { - err := opts.RunLogin(t, loginToken, skipBrowser, emailFlag, authProviderFlag) + err := opts.RunLogin(t, loginToken, skipBrowser, emailFlag, authProviderFlag, orgFlag) if err != nil { // if err is ImportIDEConfigError, log err with sentry but continue if _, ok := err.(*importideconfig.ImportIDEConfigError); !ok { @@ -102,6 +104,7 @@ func NewCmdLogin(t *terminal.Terminal, loginStore LoginStore, auth Auth) *cobra. cmd.Flags().BoolVar(&skipBrowser, "skip-browser", false, "print url instead of auto opening browser") cmd.Flags().StringVar(&emailFlag, "email", "", "email to use for authentication") cmd.Flags().StringVar(&authProviderFlag, "auth", "", "authentication provider to use (nvidia or legacy, default is nvidia)") + cmd.Flags().StringVarP(&orgFlag, "org", "o", "", "organization to use (must exist)") return cmd } @@ -130,27 +133,50 @@ func (o LoginOptions) loginAndGetOrCreateUser(loginToken string, skipBrowser boo return user, nil } -func (o LoginOptions) getOrCreateOrg(username string) (*entity.Organization, error) { - org, err := o.LoginStore.GetActiveOrganizationOrDefault() - if err != nil { - return nil, breverrors.WrapAndTrace(err) - } +func (o LoginOptions) getOrCreateOrg(username string, orgFlag string) (*entity.Organization, error) { + var org *entity.Organization + var err error - if org == nil { - newOrgName := makeFirstOrgName(username) - fmt.Printf("Creating your first org %s ...\n", newOrgName) - org, err = o.LoginStore.CreateOrganization(store.CreateOrganizationRequest{ - Name: newOrgName, - }) + if orgFlag != "" { + var orgs []entity.Organization + orgs, err = o.LoginStore.GetOrganizations(&store.GetOrganizationsOptions{Name: orgFlag}) + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + if len(orgs) == 0 { + return nil, breverrors.NewValidationError(fmt.Sprintf("no org found with name %s", orgFlag)) + } else if len(orgs) > 1 { + return nil, breverrors.NewValidationError(fmt.Sprintf("more than one org found with name %s", orgFlag)) + } + org = &orgs[0] + + err = o.LoginStore.SetDefaultOrganization(org) if err != nil { return nil, breverrors.WrapAndTrace(err) } - fmt.Println("done!") + } else { + org, err = o.LoginStore.GetActiveOrganizationOrDefault() + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + + if org == nil { + newOrgName := makeFirstOrgName(username) + fmt.Printf("Creating your first org %s ...\n", newOrgName) + org, err = o.LoginStore.CreateOrganization(store.CreateOrganizationRequest{ + Name: newOrgName, + }) + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + fmt.Println("done!") + } } + return org, nil } -func (o LoginOptions) RunLogin(t *terminal.Terminal, loginToken string, skipBrowser bool, emailFlag string, authProviderFlag string) error { +func (o LoginOptions) RunLogin(t *terminal.Terminal, loginToken string, skipBrowser bool, emailFlag string, authProviderFlag string, orgFlag string) error { tokens, _ := o.LoginStore.GetAuthTokens() if authProviderFlag != "" && authProviderFlag != "nvidia" && authProviderFlag != "legacy" { @@ -175,7 +201,7 @@ func (o LoginOptions) RunLogin(t *terminal.Terminal, loginToken string, skipBrow return breverrors.WrapAndTrace(err) } - org, err := o.getOrCreateOrg(user.Username) + org, err := o.getOrCreateOrg(user.Username, orgFlag) if err != nil { return breverrors.WrapAndTrace(err) }