Skip to content

Commit f0b3f10

Browse files
SteffenDESteffen Deusch
andauthored
Use serving name as pg group name (#1566)
Co-authored-by: Steffen Deusch <steffen.deusch@teaminternet.com>
1 parent b2fdb9a commit f0b3f10

File tree

2 files changed

+13
-8
lines changed

2 files changed

+13
-8
lines changed

nx/lib/nx/serving.ex

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1117,7 +1117,7 @@ defmodule Nx.Serving do
11171117
end
11181118

11191119
defp distributed_batched_run_with_retries!(name, input, retries) do
1120-
case :pg.get_members(Nx.Serving.PG, __MODULE__) do
1120+
case :pg.get_members(Nx.Serving.PG, name) do
11211121
[] ->
11221122
exit({:noproc, {__MODULE__, :distributed_batched_run, [name, input, [retries: retries]]}})
11231123

@@ -1332,7 +1332,7 @@ defmodule Nx.Serving do
13321332
)
13331333

13341334
serving_weight = max(1, weight * partitions_count)
1335-
:pg.join(Nx.Serving.PG, __MODULE__, List.duplicate(self(), serving_weight))
1335+
:pg.join(Nx.Serving.PG, name, List.duplicate(self(), serving_weight))
13361336

13371337
for batch_key <- batch_keys do
13381338
stack_init(batch_key)

nx/test/nx/serving_test.exs

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1288,7 +1288,8 @@ defmodule Nx.ServingTest do
12881288
]
12891289

12901290
Node.spawn_link(:"secondary@127.0.0.1", DistributedServings, :multiply, [parent, opts])
1291-
assert_receive {_, :join, Nx.Serving, _}
1291+
assert_receive {_, :join, name, _}
1292+
assert name == config.test
12921293

12931294
batch = Nx.Batch.concatenate([Nx.tensor([1, 2])])
12941295

@@ -1327,14 +1328,16 @@ defmodule Nx.ServingTest do
13271328
opts2 = Keyword.put(opts, :distribution_weight, 4)
13281329

13291330
Node.spawn_link(:"secondary@127.0.0.1", DistributedServings, :multiply, [parent, opts])
1330-
assert_receive {_, :join, Nx.Serving, pids}
1331+
assert_receive {_, :join, name, pids}
13311332
assert length(pids) == 1
1333+
assert name == config.test
13321334

13331335
Node.spawn_link(:"tertiary@127.0.0.1", DistributedServings, :multiply, [parent, opts2])
1334-
assert_receive {_, :join, Nx.Serving, pids}
1336+
assert_receive {_, :join, name, pids}
13351337
assert length(pids) == 4
1338+
assert name == config.test
13361339

1337-
members = :pg.get_members(Nx.Serving.PG, Nx.Serving)
1340+
members = :pg.get_members(Nx.Serving.PG, config.test)
13381341
assert length(members) == 5
13391342
end
13401343

@@ -1356,7 +1359,8 @@ defmodule Nx.ServingTest do
13561359

13571360
args = [parent, opts]
13581361
Node.spawn_link(:"secondary@127.0.0.1", DistributedServings, :add_five_round_about, args)
1359-
assert_receive {_, :join, Nx.Serving, _}
1362+
assert_receive {_, :join, name, _}
1363+
assert name == config.test
13601364

13611365
batch = Nx.Batch.concatenate([Nx.tensor([1, 2])])
13621366

@@ -1412,7 +1416,8 @@ defmodule Nx.ServingTest do
14121416
]
14131417

14141418
Node.spawn_link(:"tertiary@127.0.0.1", DistributedServings, :multiply, [parent, opts])
1415-
assert_receive {_, :join, Nx.Serving, _}
1419+
assert_receive {_, :join, name, _}
1420+
assert name == config.test
14161421

14171422
batch = Nx.Batch.concatenate([Nx.tensor([1, 2])])
14181423

0 commit comments

Comments
 (0)