Skip to content

Commit

Permalink
Various improvements in server
Browse files Browse the repository at this point in the history
  • Loading branch information
synw committed Jan 18, 2025
1 parent 03f55e5 commit bddf355
Show file tree
Hide file tree
Showing 8 changed files with 221 additions and 9 deletions.
4 changes: 2 additions & 2 deletions server/conf/conf.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ func InitConf() types.Conf {
viper.SetConfigName("server.config")
viper.AddConfigPath(".")
viper.SetDefault("origins", []string{"localhost"})
viper.SetDefault("cmd_api_key", nil)
viper.SetDefault("cmd_api_key", "")
viper.SetDefault("models", map[string]string{})
viper.SetDefault("features", []string{})
err := viper.ReadInConfig() // Find and read the config file
Expand All @@ -31,7 +31,7 @@ func InitConf() types.Conf {
return types.Conf{
Origins: or,
ApiKey: ak,
CmdApiKey: &cmdak,
CmdApiKey: cmdak,
Features: ft,
Models: models,
}
Expand Down
94 changes: 94 additions & 0 deletions server/httpserver/cmd.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package httpserver

import (
"encoding/json"
"fmt"
"net/http"
"sync"

"github.com/labstack/echo/v4"
"github.com/synw/agent-smith/server/lm"
"github.com/synw/agent-smith/server/state"
"github.com/synw/agent-smith/server/types"
)

func ExecuteCmdHandler(c echo.Context) error {
m := echo.Map{}
if err := c.Bind(&m); err != nil {
return err
}
cmd, ok := m["cmd"]
if !ok {
msg := "Provide a 'cmd' string parameter"
return echo.NewHTTPError(http.StatusBadRequest, msg)
}
params, ok := m["params"]
if !ok {
params = []string{}
}

c.Response().Header().Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
c.Response().WriteHeader(http.StatusOK)
ch := make(chan types.StreamedMessage)
errCh := make(chan types.StreamedMessage)

var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
rawParams := params.([]interface{})
params := lm.InterfaceToStringArray(rawParams)
lm.RunCmd(cmd.(string), params, c, ch, errCh)
}()

select {
case res, ok := <-ch:
if ok {
if state.IsDebug {
fmt.Println("-------- result ----------")
for key, value := range res.Data {
fmt.Printf("%s: %v\n", key, value)
}
fmt.Println("--------------------------")
}
}
wg.Wait()
close(ch)
close(errCh)
return nil
case err, ok := <-errCh:
if ok {
enc := json.NewEncoder(c.Response())
err := lm.StreamMsg(err, c, enc)
if err != nil {
if state.IsDebug {
fmt.Println("Streaming error", err)
errCh <- types.StreamedMessage{
Content: "Streaming error",
MsgType: types.ErrorMsgType,
}
}
wg.Wait()
close(ch)
close(errCh)
return c.NoContent(http.StatusInternalServerError)
}
} else {
wg.Wait()
close(ch)
close(errCh)
return c.JSON(http.StatusInternalServerError, err)
}
wg.Wait()
close(ch)
close(errCh)
return nil
case <-c.Request().Context().Done():
fmt.Println("\nRequest canceled")
state.ContinueInferingController = false
wg.Wait()
close(ch)
close(errCh)
return c.NoContent(http.StatusNoContent)
}
}
17 changes: 14 additions & 3 deletions server/httpserver/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ import (
"github.com/labstack/gommon/log"
)

func RunServer(origins []string, apiKey string) {
func RunServer(origins []string, apiKey string, cmdApîKey string) {
e := echo.New()

// logger
e.Use(middleware.LoggerWithConfig(middleware.LoggerConfig{
Format: "${method} ${status} ${uri} ${latency_human} ${remote_ip} ${error}\n",
Format: "${method} ${status} ${uri} ${latency_human} ${remote_ip} ${error}\n",
}))
if l, ok := e.Logger.(*log.Logger); ok {
l.SetHeader("[${time_rfc3339}] ${level}")
Expand All @@ -29,7 +29,7 @@ func RunServer(origins []string, apiKey string) {

tasks := e.Group("/task")
tasks.Use(middleware.KeyAuth(func(key string, c echo.Context) (bool, error) {
if key == apiKey {
if (key == apiKey) || (key == cmdApîKey) {
//c.Set("apiKey", key)
return true, nil
}
Expand All @@ -38,5 +38,16 @@ func RunServer(origins []string, apiKey string) {
//tasks.GET("/abort", AbortHandler)
tasks.POST("/execute", ExecuteTaskHandler)

/*cmds := e.Group("/cmd")
cmds.Use(middleware.KeyAuth(func(key string, c echo.Context) (bool, error) {
if key == cmdApîKey {
//c.Set("apiKey", key)
return true, nil
}
return false, nil
}))
//tasks.GET("/abort", AbortHandler)
cmds.POST("/execute", ExecuteCmdHandler)*/

e.Start(":5143")
}
11 changes: 9 additions & 2 deletions server/httpserver/tasks.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,17 @@ func ExecuteTaskHandler(c echo.Context) error {
taskName := ""
if ok {
taskName = v.(string)
} else {
msg := "Provide a 'task' string parameter"
return echo.NewHTTPError(http.StatusBadRequest, msg)
}
v, ok = m["prompt"]
prompt := ""
if ok {
prompt = v.(string)
} else {
msg := "Provide a 'prompt' string parameter"
return echo.NewHTTPError(http.StatusBadRequest, msg)
}
v, ok = m["vars"]
vars := make(map[string]interface{})
Expand All @@ -38,10 +44,11 @@ func ExecuteTaskHandler(c echo.Context) error {
}
found, _, tp := state.GetTask(taskName)
if !found {
msg := "Task " + taskName + " not found"
if state.IsDebug {
fmt.Println("Task", taskName, "not found")
fmt.Println(msg)
}
return c.NoContent(http.StatusBadRequest)
return echo.NewHTTPError(http.StatusBadRequest, msg)
}
ms, ok := state.ModelsConf[taskName]
if !ok {
Expand Down
68 changes: 68 additions & 0 deletions server/lm/cmd.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package lm

import (
"bufio"
"fmt"
"os/exec"

"github.com/labstack/echo/v4"
"github.com/synw/agent-smith/server/types"
)

func RunCmd(
cmdName string,
params []string,
c echo.Context,
ch chan<- types.StreamedMessage,
errCh chan<- types.StreamedMessage,
) {
// Create the command with the arguments
params = append([]string{cmdName}, params...)
fmt.Println("Params:")
for _, p := range params {
fmt.Println("-", p)
}
fmt.Println("Cmd", "lm", params)
cmd := exec.Command("lm", params...)

// Create a pipe to capture the command's output
stdout, err := cmd.StdoutPipe()
if err != nil {
msg := fmt.Errorf("Error creating stdout pipe:", err)
errCh <- createErrorMsg(msg.Error())
return
}

// Start the command
if err := cmd.Start(); err != nil {
msg := fmt.Errorf("Error starting command:", err)
errCh <- createErrorMsg(msg.Error())
return
}

// Create a scanner to read the output word by word
scanner := bufio.NewScanner(stdout)
scanner.Split(bufio.ScanWords) // Set the split function to scan words

// Read and print the output word by word
i := 0
for scanner.Scan() {
i++
token := scanner.Text()
fmt.Println("T", token)
ch <- createMsg(token, i)
}

// Check for errors during scanning
if err := scanner.Err(); err != nil {
msg := fmt.Errorf("Error reading output:", err)
fmt.Println(msg)
errCh <- createErrorMsg(msg.Error())
}

// Wait for the command to finish
if err := cmd.Wait(); err != nil {
msg := fmt.Errorf("Command finished with error:", err)
errCh <- createErrorMsg(msg.Error())
}
}
31 changes: 31 additions & 0 deletions server/lm/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package lm
import (
"fmt"
"reflect"

"github.com/synw/agent-smith/server/types"
)

func StructToMap(s interface{}) (map[string]interface{}, error) {
Expand All @@ -26,3 +28,32 @@ func StructToMap(s interface{}) (map[string]interface{}, error) {

return m, nil
}

func InterfaceToStringArray(interfaceSlice []interface{}) []string {
// Convert to slice of string
stringSlice := make([]string, 0, len(interfaceSlice))
for _, v := range interfaceSlice {
if str, ok := v.(string); ok {
stringSlice = append(stringSlice, str)
} else {
// Handle the case where the element is not a string
fmt.Printf("Skipping non-string element in interfaceToStringArray: %v\n", v)
}
}
return stringSlice
}

func createMsg(msg string, n int) types.StreamedMessage {
return types.StreamedMessage{
Content: msg,
MsgType: types.TokenMsgType,
Num: n,
}
}

func createErrorMsg(msg string) types.StreamedMessage {
return types.StreamedMessage{
Content: msg,
MsgType: types.ErrorMsgType,
}
}
3 changes: 2 additions & 1 deletion server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,13 @@ func main() {
}
state.IsVerbose = !*quiet
config := conf.InitConf()
//fmt.Println("Conf", config)
err := state.Init(config.Features, config.Models)
if err != nil {
log.Fatal("Error initializing state", err)
}
if state.IsVerbose {
fmt.Println("Starting the http server with allowed origins", config.Origins)
}
httpserver.RunServer(config.Origins, config.ApiKey)
httpserver.RunServer(config.Origins, config.ApiKey, config.CmdApiKey)
}
2 changes: 1 addition & 1 deletion server/types/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package types
type Conf struct {
Origins []string
ApiKey string
CmdApiKey *string
CmdApiKey string
Features []string
Models map[string]string
}
Expand Down

0 comments on commit bddf355

Please sign in to comment.