diff --git a/replicate/use.py b/replicate/use.py index 2ea6783..596674e 100644 --- a/replicate/use.py +++ b/replicate/use.py @@ -190,7 +190,17 @@ def _resolve_ref(obj: Any) -> Any: result = _resolve_ref(dereferenced) - # Filter out any references that have now been referenced. + # Remove "paths" as these aren't relevant to models. + result["paths"] = {} + + # Retain Input and Output schemas as these are important. + dereferenced_refs.discard("Input") + dereferenced_refs.discard("Output") + + dereferenced_refs.discard("TrainingInput") + dereferenced_refs.discard("TrainingOutput") + + # Filter out any remaining references that have been inlined. result["components"]["schemas"] = { k: v for k, v in result["components"]["schemas"].items() diff --git a/tests/test_use.py b/tests/test_use.py index 734c832..05747a3 100644 --- a/tests/test_use.py +++ b/tests/test_use.py @@ -78,6 +78,82 @@ def create_mock_version(version_overrides=None, version_id="xyz123"): "required": ["prompt"], }, "Output": {"type": "string", "title": "Output"}, + "PredictionResponse": { + "type": "object", + "title": "PredictionResponse", + "properties": { + "id": {"type": "string", "title": "Id"}, + "logs": {"type": "string", "title": "Logs", "default": ""}, + "error": {"type": "string", "title": "Error"}, + "input": {"$ref": "#/components/schemas/Input"}, + "output": {"$ref": "#/components/schemas/Output"}, + "status": {"$ref": "#/components/schemas/Status"}, + "metrics": {"type": "object", "title": "Metrics"}, + "version": {"type": "string", "title": "Version"}, + "created_at": { + "type": "string", + "title": "Created At", + "format": "date-time", + }, + "started_at": { + "type": "string", + "title": "Started At", + "format": "date-time", + }, + "completed_at": { + "type": "string", + "title": "Completed At", + "format": "date-time", + }, + }, + }, + "PredictionRequest": { + "type": "object", + "title": "PredictionRequest", + "properties": { + "id": {"type": "string", "title": "Id"}, + "input": {"$ref": "#/components/schemas/Input"}, + "webhook": { + "type": "string", + "title": "Webhook", + "format": "uri", + "maxLength": 65536, + "minLength": 1, + }, + "created_at": { + "type": "string", + "title": "Created At", + "format": "date-time", + }, + "output_file_prefix": { + "type": "string", + "title": "Output File Prefix", + }, + "webhook_events_filter": { + "type": "array", + "items": {"$ref": "#/components/schemas/WebhookEvent"}, + "default": ["start", "output", "logs", "completed"], + }, + }, + }, + "Status": { + "enum": [ + "starting", + "processing", + "succeeded", + "canceled", + "failed", + ], + "type": "string", + "title": "Status", + "description": "An enumeration.", + }, + "WebhookEvent": { + "enum": ["start", "output", "logs", "completed"], + "type": "string", + "title": "WebhookEvent", + "description": "An enumeration.", + }, } }, }, @@ -345,6 +421,7 @@ async def test_use_function_openapi_schema_dereferenced(client_mode): "openapi_schema": { "components": { "schemas": { + "Extra": {"type": "object"}, "Output": {"$ref": "#/components/schemas/ModelOutput"}, "ModelOutput": { "type": "object", @@ -374,6 +451,12 @@ async def test_use_function_openapi_schema_dereferenced(client_mode): else: schema = hotdog_detector.openapi_schema() + assert schema["components"]["schemas"]["Extra"] == {"type": "object"} + assert schema["components"]["schemas"]["Input"] == { + "type": "object", + "properties": {"prompt": {"type": "string", "title": "Prompt"}}, + "required": ["prompt"], + } assert schema["components"]["schemas"]["Output"] == { "type": "object", "properties": { @@ -386,7 +469,14 @@ async def test_use_function_openapi_schema_dereferenced(client_mode): }, } + # Assert everything else is stripped out + assert schema["paths"] == {} + + assert "PredictionRequest" not in schema["components"]["schemas"] + assert "PredictionResponse" not in schema["components"]["schemas"] assert "ModelOutput" not in schema["components"]["schemas"] + assert "Status" not in schema["components"]["schemas"] + assert "WebhookEvent" not in schema["components"]["schemas"] @pytest.mark.asyncio diff --git a/uv.lock b/uv.lock index 30c17f9..1d664bb 100644 --- a/uv.lock +++ b/uv.lock @@ -1282,7 +1282,7 @@ wheels = [ [[package]] name = "replicate" -version = "1.1.0b1" +version = "1.1.0b2" source = { editable = "." } dependencies = [ { name = "httpx" },