diff --git a/main.go b/main.go index 1835a7a..2e1433c 100644 --- a/main.go +++ b/main.go @@ -1,28 +1,49 @@ package main import ( - "fmt" - "io" - "io/ioutil" - "log" - "net" - + "crypto/tls" + "flag" + "github.com/micro/go-config" + "github.com/micro/go-config/source/env" + "github.com/micro/go-config/source/file" + mflag "github.com/micro/go-config/source/flag" + log "github.com/sirupsen/logrus" + "github.com/wercker/journalhook" "golang.org/x/crypto/ssh" + "io/ioutil" + "net" ) func main() { - // In the latest version of crypto/ssh (after Go 1.3), the SSH server type has been removed - // in favour of an SSH connection type. A ssh.ServerConn is created by passing an existing - // net.Conn and a ssh.ServerConfig to ssh.NewServerConn, in effect, upgrading the net.Conn - // into an ssh.ServerConn + flag.String("ssh-privkey", "id_rsa", "ssh private key") + flag.String("ssh-listen", "0.0.0.0:22", "ssh listen address") + flag.String("tls-cert", "", "tls certificate") + flag.String("tls-key", "", "tls certificate") + flag.String("tls-listen", "", "tls listen address") + flag.String("log-mode", "", "Logging mode (std, journal)") + flag.Parse() - config := &ssh.ServerConfig{ - NoClientAuth: true, + config.Load( + file.NewSource( + file.WithPath("/etc/fwd.json"), + ), + env.NewSource(), + mflag.NewSource(), + ) + + if config.Get("log", "mode").String("std") == "journal" { + journalhook.Enable() + } + + sshConfig := &ssh.ServerConfig{ + NoClientAuth: !config.Get("ssh", "auth").Bool(false), } // You can generate a keypair with 'ssh-keygen -t rsa' - privateBytes, err := ioutil.ReadFile("id_rsa") + key := config.Get("ssh", "privkey").String("id_rsa") + + privateBytes, err := ioutil.ReadFile(key) if err != nil { log.Fatal("Failed to load private key (./id_rsa)") } @@ -32,151 +53,43 @@ func main() { log.Fatal("Failed to parse private key") } - config.AddHostKey(private) + sshConfig.AddHostKey(private) - // Once a ServerConfig has been configured, connections can be accepted. - listener, err := net.Listen("tcp", "0.0.0.0:2200") - if err != nil { - log.Fatalf("Failed to listen on 2200 (%s)", err) + if config.Get("tls", "listen").String("") != "" { + cer, err := tls.LoadX509KeyPair( + config.Get("tls", "cert").String(""), + config.Get("tls", "key").String(""), + ) + if err != nil { + log.Println(err) + return + } + tlsConfig := &tls.Config{Certificates: []tls.Certificate{cer}} + tlsListener, err := tls.Listen("tcp", config.Get("tls", "listen").String(""), tlsConfig) + if err != nil { + log.Fatalf("TLS listen failed: %s", err) + } + + go handleListener(tlsListener, "tls", sshConfig) } - // Accept all connections - log.Print("Listening on 2200...") + listener, err := net.Listen("tcp", config.Get("ssh", "listen").String("0.0.0.0:22")) + if err != nil { + log.Fatalf("SSH listen failed: %s", err) + } + + handleListener(listener, "ssh", sshConfig) +} + +func handleListener(listener net.Listener, backend string, sshConfig *ssh.ServerConfig) { + log.Printf("Listening on %s (%s)...", listener.Addr(), backend) for { tcpConn, err := listener.Accept() if err != nil { - log.Printf("Failed to accept incoming connection (%s)", err) - continue - } - // Before use, a handshake must be performed on the incoming net.Conn. - sshConn, chans, reqs, err := ssh.NewServerConn(tcpConn, config) - if err != nil { - log.Printf("Failed to handshake (%s)", err) + log.Printf("accept failed: %s", err) continue } - log.Printf("New SSH connection from %s (%s)", sshConn.RemoteAddr(), sshConn.ClientVersion()) - go handleRequests(reqs, sshConn) - go handleChannels(chans) + go handleConnection(tcpConn, backend, sshConfig) } } - -type tcpIpForwardRequestPayload struct { - Raddr string - Rport uint32 -} - -type forwardedTcpIpRequestPayload struct { - Raddr string - Rport uint32 - Laddr string - Lport uint32 -} - -func handleRequests(reqs <-chan *ssh.Request, conn *ssh.ServerConn) { - for req := range reqs { - log.Printf("req: %s", req.Type) - switch req.Type { - case "tcpip-forward": - var payload tcpIpForwardRequestPayload - - err := ssh.Unmarshal(req.Payload, &payload) - if err != nil { - log.Printf("Malormed forward request") - continue - } - - log.Printf("%v", payload) - - listener, err := net.Listen("tcp", "0.0.0.0:0") - if err != nil { - log.Printf("Failed to listen") - req.Reply(false, nil) - continue - } - defer listener.Close() - - log.Printf("Listening on %v", listener.Addr()) - - go func() { - for { - lconn, err := listener.Accept() - log.Printf("%v %v", lconn, err) - - channelPayload := ssh.Marshal(&forwardedTcpIpRequestPayload{ - Raddr: payload.Raddr, - Rport: payload.Rport, - Laddr: "localhost", - Lport: 8000, - }) - channel, reqs, err := conn.OpenChannel("forwarded-tcpip", channelPayload) - if err != nil { - log.Printf("%s: open channel failed", err) - return - } - go func() { - defer lconn.Close() - io.Copy(lconn, channel) - }() - go func() { - defer channel.Close() - io.Copy(channel, lconn) - }() - go ssh.DiscardRequests(reqs) - } - }() - - req.Reply(true, nil) - case "keepalive@openssh.com": - req.Reply(true, nil) - default: - log.Printf("Ignoring request") - req.Reply(false, nil) - } - } -} - -func handleChannels(chans <-chan ssh.NewChannel) { - // Service the incoming Channel channel in go routine - for newChannel := range chans { - go handleChannel(newChannel) - } -} - -func handleChannel(newChannel ssh.NewChannel) { - // Since we're handling a shell, we expect a - // channel type of "session". The also describes - // "x11", "direct-tcpip" and "forwarded-tcpip" - // channel types. - if t := newChannel.ChannelType(); t != "session" { - newChannel.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", t)) - return - } - - // At this point, we have the opportunity to reject the client's - // request for another logical connection - connection, requests, err := newChannel.Accept() - if err != nil { - log.Printf("Could not accept channel (%s)", err) - return - } - - // Sessions have out-of-band requests such as "shell", "pty-req" and "env" - go func() { - for req := range requests { - log.Printf(req.Type) - switch req.Type { - case "shell": - // We only accept the default shell - // (i.e. no command in the Payload) - if len(req.Payload) == 0 { - req.Reply(true, nil) - } - } - } - }() - connection.Write([]byte("hello!\r\n")) - io.Copy(connection, connection) - log.Printf("Session closed") - return -} diff --git a/ssh.go b/ssh.go new file mode 100644 index 0000000..5e7e709 --- /dev/null +++ b/ssh.go @@ -0,0 +1,224 @@ +package main + +import ( + "fmt" + log "github.com/sirupsen/logrus" + "golang.org/x/crypto/ssh" + "io" + "net" +) + +func handleConnection(tcpConn net.Conn, backend string, sshConfig *ssh.ServerConfig) { + logger := log.WithFields(log.Fields{ + "client": tcpConn.RemoteAddr(), + "backend": backend, + }) + logger.Debug("Incoming connection") + + // Before use, a handshake must be performed on the incoming net.Conn. + sshConn, chans, reqs, err := ssh.NewServerConn(tcpConn, sshConfig) + if err != nil { + logger.WithFields(log.Fields{ + "err": err, + }).Warning("SSH Handshake failed") + return + } + + logger.WithFields(log.Fields{ + "version": string(sshConn.ClientVersion()), + "user": sshConn.User(), + }).Info("Incoming SSH connection") + + go handleRequests(reqs, sshConn) + go handleChannels(chans, sshConn) + go func() { + err := sshConn.Wait() + logger.WithFields(log.Fields{ + "err": err, + }).Info("Connection closed") + }() +} + +type tcpIpForwardRequestPayload struct { + Raddr string + Rport uint32 +} + +type forwardedTcpIpRequestPayload struct { + Raddr string + Rport uint32 + Laddr string + Lport uint32 +} + +func verifyForwardRequest(conn *ssh.ServerConn, payload tcpIpForwardRequestPayload) bool { + return payload.Rport >= 1024 +} + +func handleRequests(reqs <-chan *ssh.Request, conn *ssh.ServerConn) { + + for req := range reqs { + logger := log.WithFields(log.Fields{ + "client": conn.RemoteAddr(), + "request": req.Type, + }) + + switch req.Type { + case "tcpip-forward": + var payload tcpIpForwardRequestPayload + err := ssh.Unmarshal(req.Payload, &payload) + if err != nil { + logger.Warn("Malormed forward request") + continue + } + + if !verifyForwardRequest(conn, payload) { + req.Reply(false, nil) + continue + } + + listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", payload.Rport)) + if err != nil { + logger.Warn("Failed to listen: %s", err) + req.Reply(false, nil) + continue + } + defer listener.Close() + + logger.WithFields(log.Fields{ + "forward": listener.Addr(), + }).Info("Forwarding") + + go handleForward(listener, conn, payload) + + req.Reply(true, nil) + case "keepalive@openssh.com": + req.Reply(true, nil) + default: + logger.Debug("unhandled request") + req.Reply(false, nil) + } + } +} + +func handleForward(listener net.Listener, conn *ssh.ServerConn, payload tcpIpForwardRequestPayload) { + logger := log.WithFields(log.Fields{ + "client": conn.RemoteAddr(), + "forward": listener.Addr(), + }) + + for { + lconn, err := listener.Accept() + if err != nil { + logger.Warn("accept failed: %s", err) + return + } + logger.WithFields(log.Fields{ + "remote": lconn.RemoteAddr(), + }).Info("Forwarding connection") + + channelPayload := ssh.Marshal(&forwardedTcpIpRequestPayload{ + Raddr: payload.Raddr, + Rport: payload.Rport, + Laddr: "localhost", + Lport: 1, + }) + + channel, reqs, err := conn.OpenChannel("forwarded-tcpip", channelPayload) + if err != nil { + logger.Warn("open channel failed: %s", err) + lconn.Close() + continue + } + + go func() { + defer lconn.Close() + io.Copy(lconn, channel) + }() + + go func() { + defer channel.Close() + io.Copy(channel, lconn) + }() + + go ssh.DiscardRequests(reqs) + } +} + +func handleChannels(chans <-chan ssh.NewChannel, conn *ssh.ServerConn) { + // Service the incoming Channel channel in go routine + for newChannel := range chans { + go handleChannel(newChannel, conn) + } +} + +func handleChannel(newChannel ssh.NewChannel, conn *ssh.ServerConn) { + logger := log.WithFields(log.Fields{ + "client": conn.RemoteAddr(), + }) + + // Since we're handling a shell, we expect a + // channel type of "session". The also describes + // "x11", "direct-tcpip" and "forwarded-tcpip" + // channel types. + if t := newChannel.ChannelType(); t != "session" { + newChannel.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", t)) + return + } + + // At this point, we have the opportunity to reject the client's + // request for another logical connection + connection, requests, err := newChannel.Accept() + if err != nil { + logger.Error("Could not accept channel: %s", err) + return + } + + // Sessions have out-of-band requests such as "shell", "pty-req" and "env" + go func() { + for req := range requests { + logger := log.WithFields(log.Fields{ + "client": conn.RemoteAddr(), + "request": req.Type, + }) + switch req.Type { + case "shell": + // We only accept the default shell + // (i.e. no command in the Payload) + if len(req.Payload) == 0 { + req.Reply(true, nil) + } + case "exec": + logger.Debug("exec: %v", req.Payload) + var payload struct { + Command string + } + + err := ssh.Unmarshal(req.Payload, &payload) + if err != nil { + log.Error("Malormed exec request") + continue + } + logger.Debug("exec: %v", payload) + case "env": + var payload struct { + Name string + Value string + } + + err := ssh.Unmarshal(req.Payload, &payload) + if err != nil { + log.Error("Malormed env request") + continue + } + logger.Printf("env: %v", payload) + default: + logger.Debug("unhandled session request: %s", req.Type) + } + } + }() + connection.Write([]byte("hello!\r\n")) + io.Copy(connection, connection) + logger.Info("Session closed") + return +}