// SessionHandler.go (6.7 KB)
  1. package zsshrpc_server
  2. import (
  3. "fmt"
  4. "golang.org/x/crypto/ssh"
  5. "net"
  6. )
  7. type ZSshRpcSessionContext struct {
  8. NetConn net.Conn
  9. ServerConn *ssh.ServerConn
  10. NewChannelChan <-chan ssh.NewChannel
  11. ReqChan <-chan *ssh.Request
  12. OperationHandler ZSshRpcOperationHandler
  13. Logger SvrLogFunc
  14. }
  15. type ZSshRpcChannelContext struct {
  16. Channel ssh.Channel
  17. ReqChan <-chan *ssh.Request
  18. SessionCtx *ZSshRpcSessionContext
  19. Logger SvrLogFunc
  20. RxState ZSshRpcChState
  21. RxMethod ZSshRpcMethod
  22. RxBuf ThreadSafeBuffer
  23. UriStr string
  24. JsonStr string
  25. ResponseData ZSshRpcOperationResponse
  26. }
  27. func HandleNewSession(conn net.Conn, sshcfg *ssh.ServerConfig, logger SvrLogFunc, handler ZSshRpcOperationHandler, IOBlockSize int) {
  28. sconn, schan, reqchan, err := ssh.NewServerConn(conn, sshcfg)
  29. if err != nil {
  30. logger(SvrLogLevel_WARNING, fmt.Sprintf("Failed Handling Session From Client '%v'", conn.RemoteAddr()), err)
  31. return
  32. }
  33. rpcSessCtx := &ZSshRpcSessionContext{
  34. NetConn: conn,
  35. ServerConn: sconn,
  36. NewChannelChan: schan,
  37. ReqChan: reqchan,
  38. OperationHandler: handler,
  39. Logger: logger,
  40. }
  41. go ssh.DiscardRequests(reqchan)
  42. go handleChannels(rpcSessCtx, IOBlockSize)
  43. }
  44. func try(Logger SvrLogFunc) {
  45. if err := recover(); err != nil {
  46. var einf error
  47. switch err.(type) {
  48. case error:
  49. einf = err.(error)
  50. break
  51. default:
  52. einf = nil
  53. break
  54. }
  55. Logger(SvrLogLevel_WARNING, "Internal Error Recovered", einf)
  56. }
  57. }
  58. func handleChannels(ctx *ZSshRpcSessionContext, IOBlockSize int) {
  59. defer try(ctx.Logger)
  60. for newchan := range ctx.NewChannelChan {
  61. go handleChannel(ctx, newchan, IOBlockSize)
  62. }
  63. }
  64. func handleChannel(sessctx *ZSshRpcSessionContext, nch ssh.NewChannel, IOBlockSize int) {
  65. defer try(sessctx.Logger)
  66. if nch.ChannelType() == "zsshrpc" {
  67. ch, req, err := nch.Accept()
  68. if err != nil {
  69. sessctx.Logger(SvrLogLevel_INFO, fmt.Sprintf("Failed Handler Channel From '%v'", sessctx.NetConn.RemoteAddr()), err)
  70. sessctx.ServerConn.Close()
  71. } else {
  72. chctx := &ZSshRpcChannelContext{
  73. SessionCtx: sessctx,
  74. Channel: ch,
  75. ReqChan: req,
  76. Logger: sessctx.Logger,
  77. RxState: ChannelState_IDLE,
  78. }
  79. go handleZSshRpcChannel(chctx, IOBlockSize)
  80. }
  81. } else {
  82. sessctx.Logger(SvrLogLevel_INFO, fmt.Sprintf("Unsupported Channel Type '%s' From Client '%v'", nch.ChannelType(), sessctx.NetConn.RemoteAddr()), nil)
  83. nch.Reject(ssh.UnknownChannelType, "Unsupported Channel Type.")
  84. sessctx.ServerConn.Close()
  85. }
  86. }
  87. func readChannelData(b []byte, recvBufLen int, chctx *ZSshRpcChannelContext) {
  88. for {
  89. rl, err := chctx.Channel.Read(b)
  90. if err != nil || rl == 0 {
  91. break
  92. }
  93. if rl < recvBufLen {
  94. if rl != 0 {
  95. if b[rl-1] == 0 {
  96. chctx.RxBuf.Write(b[:rl-1])
  97. break
  98. } else {
  99. chctx.RxBuf.Write(b[:rl])
  100. }
  101. }
  102. } else {
  103. if b[recvBufLen-1] == 0 {
  104. chctx.RxBuf.Write(b[:recvBufLen-1])
  105. break
  106. } else {
  107. chctx.RxBuf.Write(b)
  108. }
  109. }
  110. }
  111. }
  112. func handleZSshRpcChannel(chctx *ZSshRpcChannelContext, IOBlockSize int) {
  113. defer try(chctx.Logger)
  114. ioblockbuf := make([]byte, IOBlockSize)
  115. waitchan := make(chan int)
  116. go func() {
  117. defer try(chctx.Logger)
  118. for {
  119. readChannelData(ioblockbuf, IOBlockSize, chctx)
  120. //chctx.Logger(SvrLogLevel_DEBUG, fmt.Sprintf("<RECV> [len=%d] %s",chctx.RxBuf.Len(), hex.EncodeToString([]byte(chctx.RxBuf.String()))), nil)
  121. waitchan <- 1
  122. }
  123. //io.Copy(&chctx.RxBuf, chctx.Channel)
  124. }()
  125. for {
  126. req := <-chctx.ReqChan
  127. switch chctx.RxState {
  128. case ChannelState_IDLE:
  129. {
  130. if req.Type == "ZSSHRPC-1.0" {
  131. chctx.RxState = ChannelState_WAIT_URI
  132. chctx.RxBuf.Reset()
  133. }
  134. }
  135. break
  136. case ChannelState_WAIT_URI:
  137. {
  138. //chctx.Logger(SvrLogLevel_DEBUG, fmt.Sprintf("[Before URL Clear]Length Of RxBuf: %d", chctx.RxBuf.Len()), nil)
  139. switch req.Type {
  140. case "CALL":
  141. {
  142. chctx.RxMethod = RpcMethod_CALL
  143. chctx.RxState = ChannelState_WAIT_JSON
  144. <-waitchan
  145. chctx.UriStr = chctx.RxBuf.String()
  146. chctx.RxBuf.Reset()
  147. }
  148. break
  149. case "ADD":
  150. {
  151. chctx.RxMethod = RpcMethod_ADD
  152. chctx.RxState = ChannelState_WAIT_JSON
  153. <-waitchan
  154. chctx.UriStr = chctx.RxBuf.String()
  155. chctx.RxBuf.Reset()
  156. }
  157. break
  158. case "DEL":
  159. {
  160. chctx.RxMethod = RpcMethod_DEL
  161. chctx.RxState = ChannelState_WAIT_JSON
  162. <-waitchan
  163. chctx.UriStr = chctx.RxBuf.String()
  164. chctx.RxBuf.Reset()
  165. }
  166. break
  167. case "GET":
  168. {
  169. chctx.RxMethod = RpcMethod_GET
  170. chctx.RxState = ChannelState_WAIT_JSON
  171. <-waitchan
  172. chctx.UriStr = chctx.RxBuf.String()
  173. chctx.RxBuf.Reset()
  174. }
  175. break
  176. case "SET":
  177. {
  178. chctx.RxMethod = RpcMethod_SET
  179. chctx.RxState = ChannelState_WAIT_JSON
  180. <-waitchan
  181. chctx.UriStr = chctx.RxBuf.String()
  182. chctx.RxBuf.Reset()
  183. }
  184. break
  185. default:
  186. {
  187. chctx.RxState = ChannelState_IDLE
  188. chctx.RxBuf.Reset()
  189. }
  190. break
  191. }
  192. chctx.Logger(SvrLogLevel_DEBUG, fmt.Sprint("Recv URI: ", chctx.UriStr), nil)
  193. //chctx.Logger(SvrLogLevel_DEBUG, fmt.Sprintf("[After URI Clear]Length Of RxBuf: %d", chctx.RxBuf.Len()), nil)
  194. }
  195. break
  196. case ChannelState_WAIT_JSON:
  197. {
  198. if req.Type == "ENDREQ" {
  199. //chctx.Logger(SvrLogLevel_DEBUG, fmt.Sprintf("[Before JSON Clear]Length Of RxBuf: %d", chctx.RxBuf.Len()), nil)
  200. <-waitchan
  201. chctx.JsonStr = chctx.RxBuf.String()
  202. //chctx.Logger(SvrLogLevel_DEBUG, fmt.Sprint("Recv JSON: ", chctx.JsonStr), nil)
  203. chctx.RxBuf.Reset()
  204. //chctx.Logger(SvrLogLevel_DEBUG, fmt.Sprintf("[After JSON Clear]Length Of RxBuf: %d", chctx.RxBuf.Len()), nil)
  205. chctx.RxState = ChannelState_EXEC_HANDLER
  206. req := ZSshRpcOperationRequest{
  207. ChannelContext: chctx,
  208. Method: chctx.RxMethod,
  209. URI: chctx.UriStr,
  210. JSON: chctx.JsonStr,
  211. }
  212. chctx.ResponseData = chctx.SessionCtx.OperationHandler.HandleOperation(req)
  213. rcstring := fmt.Sprintf(
  214. "%d - %s",
  215. chctx.ResponseData.StatusCode,
  216. GetResponseStatusCodeString(chctx.ResponseData.StatusCode),
  217. )
  218. chctx.Channel.Write([]byte(rcstring))
  219. chctx.Channel.Write([]byte{0})
  220. chctx.RxState = ChannelState_WAIT_JSON_READ
  221. } else {
  222. chctx.RxState = ChannelState_IDLE
  223. chctx.RxBuf.Reset()
  224. }
  225. }
  226. break
  227. case ChannelState_EXEC_HANDLER:
  228. {
  229. chctx.RxState = ChannelState_IDLE
  230. chctx.RxBuf.Reset()
  231. }
  232. break
  233. case ChannelState_WAIT_JSON_READ:
  234. if req.Type == "GETJSON" {
  235. chctx.Channel.Write([]byte(chctx.ResponseData.ResponseJSON))
  236. chctx.Channel.Write([]byte{0})
  237. }
  238. chctx.RxState = ChannelState_IDLE
  239. break
  240. }
  241. if req.WantReply {
  242. req.Reply(true, nil)
  243. }
  244. }
  245. }