Add ability to pull vllm-compatible hf models #535
base: main
Summary of Changes
Hello @ericcurtin, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request significantly enhances the model pulling capabilities by integrating native support for HuggingFace models. It allows the system to directly interact with HuggingFace Hub, download raw model files like safetensors and associated configurations, and then assemble them into OCI-compliant model artifacts. This new functionality acts as a seamless fallback, ensuring that models hosted on HuggingFace can be consumed even if they don't explicitly provide an OCI manifest, thereby broadening the range of supported models.
Hey - I've found 2 issues, and left some high level feedback:
- In `Downloader.DownloadAll`, the progress-reporting goroutine never exits unless `ctx` is canceled, but you wait on `<-progressDone` after the downloads finish; this will deadlock for normal successful pulls. Consider signaling the goroutine to stop when all downloads are complete instead of only relying on `ctx.Done`.
- In `BuildModel`/`createConfigArchive`, when there are only chat-template config files, `createConfigArchive` returns an empty path (`""`), which is still passed to `os.Remove` and `WithConfigArchive`; you should either skip calling `WithConfigArchive` (and `Remove`) when the archive path is empty or have `createConfigArchive` return a non-empty path only when it actually created an archive.
- The `tempDir` parameter to `createConfigArchive` is currently unused; either remove it from the signature or use it to control where the temporary archive is created to avoid confusion about where intermediate files are written.
Prompt for AI Agents
Please address the comments from this code review:
## Overall Comments
- In `Downloader.DownloadAll`, the progress-reporting goroutine never exits unless `ctx` is canceled, but you `(<-progressDone)` after the downloads finish; this will deadlock for normal successful pulls—consider signaling the goroutine to stop when all downloads are complete instead of only relying on `ctx.Done`.
- In `BuildModel`/`createConfigArchive`, when there are only chat-template config files, `createConfigArchive` returns an empty path (`""`), which is still passed to `os.Remove` and `WithConfigArchive`; you should either skip calling `WithConfigArchive` (and `Remove`) when the archive path is empty or have `createConfigArchive` return a non-empty path only when it actually created an archive.
- The `tempDir` parameter to `createConfigArchive` is currently unused; either remove it from the signature or use it to control where the temporary archive is created to avoid confusion about where intermediate files are written.
## Individual Comments
### Comment 1
<location> `pkg/distribution/huggingface/downloader.go:60` </location>
<code_context>
+ localPaths := make(map[string]string, len(files))
+
+ // Progress reporting goroutine
+ progressDone := make(chan struct{})
+ if progressWriter != nil {
+ go func() {
</code_context>
<issue_to_address>
**issue (bug_risk):** Progress goroutine never terminates, causing DownloadAll to block when waiting on progressDone.
The goroutine only terminates on `ctx.Done`, while `DownloadAll` waits on `<-progressDone` after `wg.Wait()`. If the caller never cancels the context and `progressWriter` is non-nil, this will deadlock. You could add an internal `done` channel or use a `context.WithCancel` inside `DownloadAll`, cancel it after `wg.Wait()`, and have the goroutine also select on that signal so it exits once downloads complete.
</issue_to_address>
### Comment 2
<location> `pkg/distribution/huggingface/client_test.go:60-69` </location>
<code_context>
+ }
+}
+
+func TestClientDownloadFile(t *testing.T) {
+ expectedContent := "test file content"
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path == "/test-org/test-model/resolve/main/test.txt" {
+ w.Header().Set("Content-Length", "17")
+ w.Write([]byte(expectedContent))
+ return
+ }
+ http.NotFound(w, r)
+ }))
+ defer server.Close()
+
+ client := NewClient(WithBaseURL(server.URL))
+
+ reader, size, err := client.DownloadFile(context.Background(), "test-org/test-model", "main", "test.txt")
+ if err != nil {
+ t.Fatalf("DownloadFile failed: %v", err)
+ }
+ defer reader.Close()
+
+ if size != 17 {
+ t.Errorf("Expected size 17, got %d", size)
+ }
+
+ content, err := io.ReadAll(reader)
+ if err != nil {
+ t.Fatalf("ReadAll failed: %v", err)
+ }
+
+ if string(content) != expectedContent {
+ t.Errorf("Expected content %q, got %q", expectedContent, string(content))
+ }
</code_context>
<issue_to_address>
**suggestion (testing):** Consider adding tests for additional response codes and headers in the HuggingFace client
Current tests cover the happy path and auth/not-found cases, but `client.checkResponse` also handles `StatusTooManyRequests` and a generic fallback that reads up to 1KB of the body. Please add tests to:
- Assert `DownloadFile`/`ListFiles` return `RateLimitError` for 429 responses.
- Assert an unexpected status (e.g., 500) returns a generic error that includes the truncated body.
- Optionally, assert `DownloadFile` defaults the revision to `"main"` when an empty revision is passed, mirroring `TestClientListFilesDefaultRevision`.
This will exercise all `checkResponse` branches and the default-revision behavior.
Suggested implementation:
```golang
	defer server.Close()
}

func TestClientDownloadFile_RateLimit(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Simulate HuggingFace rate limiting
		w.WriteHeader(http.StatusTooManyRequests)
	}))
	defer server.Close()

	client := NewClient(WithBaseURL(server.URL))

	reader, _, err := client.DownloadFile(context.Background(), "test-org/test-model", "main", "test.txt")
	if reader != nil {
		// In case DownloadFile returns a reader before noticing the error,
		// make sure we close it to avoid leaks in the test.
		reader.Close()
	}
	if err == nil {
		t.Fatal("expected RateLimitError, got nil")
	}
	if _, ok := err.(*RateLimitError); !ok {
		t.Fatalf("expected RateLimitError, got %T: %v", err, err)
	}
}

func TestClientListFiles_RateLimit(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Simulate HuggingFace rate limiting
		if r.URL.Path == "/api/models/test-org/test-model/tree/main" {
			w.WriteHeader(http.StatusTooManyRequests)
			return
		}
		http.NotFound(w, r)
	}))
	defer server.Close()

	client := NewClient(WithBaseURL(server.URL))

	_, err := client.ListFiles(context.Background(), "test-org/test-model", "main")
	if err == nil {
		t.Fatal("expected RateLimitError, got nil")
	}
	if _, ok := err.(*RateLimitError); !ok {
		t.Fatalf("expected RateLimitError, got %T: %v", err, err)
	}
}

func TestClientListFiles_UnexpectedStatusIncludesBody(t *testing.T) {
	const partialBody = "unexpected error body from server"

	// Make a body longer than 1KB so we exercise the truncation path.
	longBody := partialBody
	for len(longBody) < 2048 {
		longBody += " " + partialBody
	}

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path == "/api/models/test-org/test-model/tree/main" {
			w.WriteHeader(http.StatusInternalServerError)
			_, _ = w.Write([]byte(longBody))
			return
		}
		http.NotFound(w, r)
	}))
	defer server.Close()

	client := NewClient(WithBaseURL(server.URL))

	_, err := client.ListFiles(context.Background(), "test-org/test-model", "main")
	if err == nil {
		t.Fatal("expected error for 500 response, got nil")
	}

	msg := err.Error()
	// The generic branch should include at least the beginning of the body.
	if !strings.Contains(msg, partialBody) {
		t.Errorf("expected error message to contain response body %q, got %q", partialBody, msg)
	}
	// Optionally (depending on checkResponse implementation), we can also assert
	// that the error message does not contain the full long body, to ensure it was truncated.
	if len(longBody) > 1024 && strings.Contains(msg, longBody) {
		t.Errorf("expected response body to be truncated in error message, but full body was present")
	}
}

func TestClientDownloadFile_DefaultRevision(t *testing.T) {
	expectedContent := "test file content"

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// When revision is empty, DownloadFile should default to "main".
		if r.URL.Path != "/test-org/test-model/resolve/main/test.txt" {
			t.Errorf("unexpected request path: got %q, want %q", r.URL.Path, "/test-org/test-model/resolve/main/test.txt")
		}
		w.Header().Set("Content-Length", "17")
		_, _ = w.Write([]byte(expectedContent))
	}))
	defer server.Close()

	client := NewClient(WithBaseURL(server.URL))

	reader, size, err := client.DownloadFile(context.Background(), "test-org/test-model", "", "test.txt")
	if err != nil {
		t.Fatalf("DownloadFile failed: %v", err)
	}
	defer reader.Close()

	if size != 17 {
		t.Errorf("Expected size 17, got %d", size)
	}

	content, err := io.ReadAll(reader)
	if err != nil {
		t.Fatalf("ReadAll failed: %v", err)
	}

	if string(content) != expectedContent {
		t.Errorf("Expected content %q, got %q", expectedContent, string(content))
	}
```
1. Ensure the file imports `strings` (and that `RateLimitError` is available in this package):
- Update the import block to include:
- `"strings"`
For example:
```go
import (
	"context"
	"encoding/json"
	"io"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
)
```
Adjust to match the existing imports and ordering in `client_test.go`.
2. The `RateLimitError` type name is inferred from `client.checkResponse` behavior described in your comment. If the concrete type is named differently (e.g. `*rateLimitError` or defined in another package), update the type assertion in the two `RateLimit` tests accordingly.
3. The `TestClientListFiles_UnexpectedStatusIncludesBody` assertions assume that `checkResponse`:
- Includes at least the first part of the body in `err.Error()`.
- Truncates bodies larger than 1KB so the full body is not present in the error string.
If your `checkResponse` formats the error message differently (e.g. prefixes with `"huggingface: "` or wraps the body in quotes), you may need to tweak the `strings.Contains` expectations (for example, checking for a shorter substring or adjusting the truncation assertion) to match the actual error text.
</issue_to_address>
Code Review
This pull request introduces a significant and well-structured feature to natively pull HuggingFace models that are not in OCI format. The changes include a new HuggingFace client, a parallel downloader, and integration into the existing pull mechanism. The code is generally of high quality with good separation of concerns and includes comprehensive tests.
My review focuses on improving error handling robustness and test coverage to make the new functionality even more solid. Specifically, I've pointed out a potential silent failure, a couple of places where errors are ignored, and opportunities to improve test completeness and idiomatic Go in tests.
This commit introduces native HuggingFace model support by adding a new HuggingFace client implementation that can download safetensors files directly from HuggingFace Hub repositories. The changes include:
A new HuggingFace client with authentication, file listing, and download capabilities. The client handles LFS files, error responses, and rate limiting appropriately.
A downloader component that manages parallel file downloads with progress reporting and temporary file storage. It includes progress tracking and concurrent download limiting.
Model building functionality that downloads files from HuggingFace repositories and constructs OCI model artifacts using the existing builder framework.
Repository utilities for file classification, filtering, and size calculations to identify safetensors and config files needed for model construction.
Integration with the existing pull mechanism to detect HuggingFace references and attempt native pulling when no OCI manifest is found. This preserves existing OCI functionality while adding fallback support for raw HuggingFace repositories.
Signed-off-by: Eric Curtin <eric.curtin@docker.com>
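As a rough illustration of the file classification, filtering, and size calculation step described above, consider the sketch below. The `RepoFile` type, the `FileKind` constants, and the rule for what counts as a config file are all assumptions made for the example; the PR's actual utilities may classify files differently.

```go
package main

import (
	"fmt"
	"path"
	"strings"
)

// RepoFile is a hypothetical view of one entry in a repository listing.
type RepoFile struct {
	Path string
	Size int64
}

// FileKind is a hypothetical classification of repository files.
type FileKind int

const (
	KindOther FileKind = iota
	KindSafetensors
	KindConfig
)

// classify uses a deliberately simple, assumed rule: *.safetensors files are
// weights, JSON files and tokenizer.model are config, everything else is skipped.
func classify(f RepoFile) FileKind {
	name := strings.ToLower(path.Base(f.Path))
	switch {
	case strings.HasSuffix(name, ".safetensors"):
		return KindSafetensors
	case strings.HasSuffix(name, ".json"), name == "tokenizer.model":
		return KindConfig
	default:
		return KindOther
	}
}

// selectModelFiles filters a listing down to the files needed to build a
// model and sums their sizes, for example to drive a progress total.
func selectModelFiles(files []RepoFile) (weights, configs []RepoFile, total int64) {
	for _, f := range files {
		switch classify(f) {
		case KindSafetensors:
			weights = append(weights, f)
			total += f.Size
		case KindConfig:
			configs = append(configs, f)
			total += f.Size
		}
	}
	return weights, configs, total
}

func main() {
	files := []RepoFile{
		{Path: "model-00001-of-00002.safetensors", Size: 100},
		{Path: "model-00002-of-00002.safetensors", Size: 50},
		{Path: "config.json", Size: 3},
		{Path: "README.md", Size: 7},
	}
	weights, configs, total := selectModelFiles(files)
	fmt.Println(len(weights), len(configs), total)
}
```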