@@ -6,9 +6,12 @@ import (
66 "fmt"
77 "net"
88 "net/url"
9+ "path"
910 "strconv"
11+ "time"
1012
1113 "github.com/wzshiming/sshd"
14+ "github.com/wzshiming/sshproxy/permissions"
1215 "golang.org/x/crypto/ssh"
1316)
1417
@@ -24,14 +27,15 @@ type SimpleServer struct {
2427
2528// NewSimpleServer creates a new NewSimpleServer
2629func NewSimpleServer (addr string ) (* SimpleServer , error ) {
27- user , pwd , host , config , err := serverConfig (addr )
30+ user , pwd , host , config , userPermissions , err := serverConfig (addr )
2831 if err != nil {
2932 return nil , err
3033 }
3134
3235 s := & SimpleServer {
3336 Server : Server {
34- ServerConfig : * config ,
37+ ServerConfig : * config ,
38+ UserPermissions : userPermissions ,
3539 },
3640 Network : "tcp" ,
3741 Address : host ,
@@ -41,10 +45,10 @@ func NewSimpleServer(addr string) (*SimpleServer, error) {
4145 return s , nil
4246}
4347
44- func serverConfig (addr string ) (host , user , pwd string , config * ssh.ServerConfig , err error ) {
48+ func serverConfig (addr string ) (host , user , pwd string , config * ssh.ServerConfig , userPermissions func ( user string ) sshd. Permissions , err error ) {
4549 ur , err := url .Parse (addr )
4650 if err != nil {
47- return "" , "" , "" , nil , err
51+ return "" , "" , "" , nil , nil , err
4852 }
4953
5054 isPwd := false
@@ -71,45 +75,104 @@ func serverConfig(addr string) (host, user, pwd string, config *ssh.ServerConfig
7175
7276 hostkeyDatas , err := getQuery (ur .Query ()["hostkey_data" ], ur .Query ()["hostkey_file" ])
7377 if err != nil {
74- return "" , "" , "" , nil , err
78+ return "" , "" , "" , nil , nil , err
7579 }
7680 if len (hostkeyDatas ) == 0 {
7781 key , err := sshd .RandomHostkey ()
7882 if err != nil {
79- return "" , "" , "" , nil , err
83+ return "" , "" , "" , nil , nil , err
8084 }
8185 config .AddHostKey (key )
8286 } else {
8387 for _ , data := range hostkeyDatas {
8488 key , err := sshd .ParseHostkey (data )
8589 if err != nil {
86- return "" , "" , "" , nil , err
90+ return "" , "" , "" , nil , nil , err
8791 }
8892 config .AddHostKey (key )
8993 }
9094 }
9195
96+ pks := []func (conn ssh.ConnMetadata , key ssh.PublicKey ) (* ssh.Permissions , error ){}
9297 authorizedDatas , err := getQuery (ur .Query ()["authorized_data" ], ur .Query ()["authorized_file" ])
9398 if err != nil {
94- return "" , "" , "" , nil , err
99+ return "" , "" , "" , nil , nil , err
95100 }
96- allKeys := map [string ]string {}
97- for _ , data := range authorizedDatas {
98- keys , err := sshd .ParseAuthorized (bytes .NewBuffer (data ))
101+ if len (authorizedDatas ) != 0 {
102+ keys , err := sshd .ParseAuthorized (bytes .NewBuffer (bytes .Join (authorizedDatas , []byte {'\n' })))
99103 if err != nil {
100- return "" , "" , "" , nil , err
104+ return "" , "" , "" , nil , nil , err
101105 }
102- for k , v := range keys {
103- allKeys [k ] = v
106+ if len (keys .Data ) != 0 {
107+ pks = append (pks , func (conn ssh.ConnMetadata , key ssh.PublicKey ) (* ssh.Permissions , error ) {
108+ ok , _ := keys .Allow (key )
109+ if ok {
110+ return nil , nil
111+ }
112+ return nil , fmt .Errorf ("denied" )
113+ })
104114 }
105115 }
106- if len (allKeys ) != 0 {
107- config .PublicKeyCallback = func (conn ssh.ConnMetadata , key ssh.PublicKey ) (* ssh.Permissions , error ) {
108- k := string (key .Marshal ())
109- if _ , ok := allKeys [k ]; ok {
116+
117+ homeDirs := ur .Query ()["home_dir" ]
118+ if len (homeDirs ) != 0 && homeDirs [0 ] != "" {
119+ homeDir := homeDirs [0 ]
120+ sshDirName := ".ssh"
121+ sshDirNames := ur .Query ()["ssh_dir_name" ]
122+ if len (sshDirNames ) != 0 {
123+ sshDirName = sshDirNames [0 ]
124+ }
125+ authorizedFileName := "authorized_keys"
126+ authorizedFileNames := ur .Query ()["authorized_file_name" ]
127+ if len (authorizedFileNames ) != 0 {
128+ authorizedFileName = authorizedFileNames [0 ]
129+ }
130+ pks = append (pks , func (conn ssh.ConnMetadata , key ssh.PublicKey ) (* ssh.Permissions , error ) {
131+ file := path .Join (homeDir , conn .User (), sshDirName , authorizedFileName )
132+ keys , err := sshd .GetAuthorizedFile (file )
133+ if err != nil {
134+ return nil , fmt .Errorf ("denied" )
135+ }
136+ ok , _ := keys .Allow (key )
137+ if ok {
110138 return nil , nil
111139 }
112140 return nil , fmt .Errorf ("denied" )
141+ })
142+
143+ // Other sshd implementations do not have such fine-grained permissions control,
144+ // and this is a fine-grained set of permissions control files defined by the project itself
145+ permissionsFileName := ""
146+ permissionsFileNames := ur .Query ()["permissions_file_name" ]
147+ if len (permissionsFileNames ) != 0 {
148+ permissionsFileName = permissionsFileNames [0 ]
149+ }
150+ if permissionsFileName != "" {
151+ permissionsFileUpdatePeriod := time .Duration (0 )
152+ permissionsFileUpdatePeriods := ur .Query ()["permissions_file_update_period" ]
153+ if len (permissionsFileUpdatePeriods ) != 0 {
154+ permissionsFileUpdatePeriod , _ = time .ParseDuration (permissionsFileUpdatePeriods [0 ])
155+ }
156+ userPermissions = func (user string ) sshd.Permissions {
157+ file := path .Join (homeDir , user , sshDirName , permissionsFileName )
158+ return permissions .NewPermissionsFromFile (file , permissionsFileUpdatePeriod )
159+ }
160+ }
161+ }
162+
163+ if len (pks ) != 0 {
164+ if len (pks ) == 1 {
165+ config .PublicKeyCallback = pks [0 ]
166+ } else {
167+ config .PublicKeyCallback = func (conn ssh.ConnMetadata , key ssh.PublicKey ) (p * ssh.Permissions , err error ) {
168+ for _ , pk := range pks {
169+ p , err = pk (conn , key )
170+ if err == nil {
171+ break
172+ }
173+ }
174+ return
175+ }
113176 }
114177 }
115178
@@ -125,7 +188,7 @@ func serverConfig(addr string) (host, user, pwd string, config *ssh.ServerConfig
125188 port = "22"
126189 }
127190 host = net .JoinHostPort (host , port )
128- return user , pwd , host , config , nil
191+ return user , pwd , host , config , userPermissions , nil
129192}
130193
131194// Run the server
0 commit comments