diff --git a/client.go b/client.go index 656e003c..c9c39b04 100644 --- a/client.go +++ b/client.go @@ -387,7 +387,7 @@ func (c *Client) ReadStreamRange(path string, offset, length int64) (io.ReadClos // Write writes data to a given path func (c *Client) Write(path string, data []byte, _ os.FileMode) (err error) { - s, err := c.put(path, bytes.NewReader(data)) + s, err := c.put(path, bytes.NewReader(data), int64(len(data))) if err != nil { return } @@ -403,7 +403,7 @@ func (c *Client) Write(path string, data []byte, _ os.FileMode) (err error) { return } - s, err = c.put(path, bytes.NewReader(data)) + s, err = c.put(path, bytes.NewReader(data), int64(len(data))) if err != nil { return } @@ -423,7 +423,29 @@ func (c *Client) WriteStream(path string, stream io.Reader, _ os.FileMode) (err return err } - s, err := c.put(path, stream) + contentLength := int64(0) + if seeker, ok := stream.(io.Seeker); ok { + contentLength, err = seeker.Seek(0, io.SeekEnd) + if err != nil { + return err + } + + _, err = seeker.Seek(0, io.SeekStart) + if err != nil { + return err + } + } else { + buffer := bytes.NewBuffer(make([]byte, 0, 1024 * 1024 /* 1MB */)) + + contentLength, err = io.Copy(buffer, stream) + if err != nil { + return err + } + + stream = buffer + } + + s, err := c.put(path, stream, contentLength) if err != nil { return err } diff --git a/client_test.go b/client_test.go index 65724fbf..016d9acf 100644 --- a/client_test.go +++ b/client_test.go @@ -39,6 +39,25 @@ func basicAuth(h http.Handler) http.HandlerFunc { } } +func basicAuthWithPostHandlerFunc(h http.Handler, postHandlerFunc http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + user, passwd, ok := r.BasicAuth() + if !ok { + w.Header().Set("WWW-Authenticate", `Basic realm="x"`) + w.WriteHeader(401) + return + } + + if user != "user" || passwd != "password" { + http.Error(w, "not authorized", 403) + return + } + + h.ServeHTTP(w, r) + postHandlerFunc(w, r) + } +} + func multipleAuth(h http.Handler) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { notAuthed := false @@ -130,6 +149,38 @@ func newAuthSrv(t *testing.T, auth func(h http.Handler) http.HandlerFunc) (*http return srv, fs, ctx } +func newAuthServerAcquireContentLength(t *testing.T) (*Client, *httptest.Server, webdav.FileSystem, context.Context) { + srv, fs, ctx := newAuthSrvAcquireContentLength(t, basicAuthWithPostHandlerFunc) + cli := NewClient(srv.URL, "user", "password") + return cli, srv, fs, ctx +} + +func newAuthSrvAcquireContentLength(t *testing.T, authWithPostHandlerFunc func(h http.Handler, postHandlerFunc http.HandlerFunc) http.HandlerFunc) (*httptest.Server, webdav.FileSystem, context.Context) { + mux := http.NewServeMux() + fs := webdav.NewMemFS() + ctx := fillFs(t, fs) + mux.HandleFunc("/", authWithPostHandlerFunc(&webdav.Handler{ + FileSystem: fs, + LockSystem: webdav.NewMemLS(), + }, func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPut { + return + } + + fileName := strings.TrimPrefix(r.URL.Path, "/") + stat, err := fs.Stat(ctx, fileName) + if err != nil { + t.Fatalf("got: %v, want nil", err) + } + + if r.ContentLength != stat.Size() { + t.Fatalf("acquire content length got: %v, want %v", r.ContentLength, stat.Size()) + } + })) + srv := httptest.NewServer(mux) + return srv, fs, ctx +} + func TestConnect(t *testing.T) { cli, srv, _, _ := newServer(t) defer srv.Close() @@ -572,3 +623,21 @@ func TestWriteStreamFromPipe(t *testing.T) { t.Fatalf("got: %v, want file size: %d bytes", info.Size(), 8) } } + +func TestWriteToServerAcquireContentLength(t *testing.T) { + cli, srv, _, _ := newAuthServerAcquireContentLength(t) + defer srv.Close() + + if err := cli.Write("/newfile.txt", []byte("foo bar\n"), 0660); err != nil { + t.Fatalf("got: %v, want nil", err) + } +} + +func TestWriteStreamToServerAcquireContentLength(t *testing.T) { + cli, srv, _, _ := newAuthServerAcquireContentLength(t) + defer srv.Close() + + if err := cli.WriteStream("/newfile.txt", strings.NewReader("foo bar\n"), 0660); err != nil { + t.Fatalf("got: %v, want nil", err) + } +} diff --git a/requests.go b/requests.go index 8e362e86..b51e5c04 100644 --- a/requests.go +++ b/requests.go @@ -160,8 +160,10 @@ func (c *Client) copymove(method string, oldpath string, newpath string, overwri return NewPathError(method, oldpath, s) } -func (c *Client) put(path string, stream io.Reader) (status int, err error) { - rs, err := c.req("PUT", path, stream, nil) +func (c *Client) put(path string, stream io.Reader, contentLength int64) (status int, err error) { + rs, err := c.req("PUT", path, stream, func(r *http.Request) { + r.ContentLength = contentLength + }) if err != nil { return }