Skip to content

Commit 8a5ee2c

Browse files
committed
feat: send bytecode and signature to device given handle_info message
1 parent 3803a05 commit 8a5ee2c

File tree

7 files changed

+102
-9
lines changed

7 files changed

+102
-9
lines changed

lib/nx_iree/compiler.ex

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,15 @@ defmodule NxIREE.Compiler do
33
Compiler for Nx defn
44
"""
55

6+
def to_bytecode(fun, templates, opts \\ []) do
7+
opts = opts |> Keyword.put(:output_mode, :bytecode) |> Keyword.put(:compiler, __MODULE__)
8+
9+
Nx.Defn.compile(fun, templates, opts)
10+
catch
11+
{:bytecode, %{bytecode: bytecode, output_container: output_container}} ->
12+
{:ok, %{bytecode: bytecode, output_container: output_container}}
13+
end
14+
615
@behaviour Nx.Defn.Compiler
716

817
@impl true
@@ -22,9 +31,7 @@ defmodule NxIREE.Compiler do
2231
bytecode = NxIREE.compile(mlir_module, iree_compiler_flags)
2332

2433
if output_mode == :bytecode do
25-
fn _ ->
26-
[bytecode]
27-
end
34+
throw({:bytecode, %{bytecode: bytecode, output_container: output_container}})
2835
else
2936
fn [inputs] ->
3037
{:ok, results} =

liveview_native/live_nx_iree/config/config.exs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import Config
99

1010
config :live_nx_iree,
1111
namespace: LiveNxIREE,
12-
ecto_repos: [LiveNxIREE.Repo],
12+
ecto_repos: [],
1313
generators: [timestamp_type: :utc_datetime]
1414

1515
# Configures the endpoint

liveview_native/live_nx_iree/lib/live_nx_iree/application.ex

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ defmodule LiveNxIREE.Application do
99
def start(_type, _args) do
1010
children = [
1111
LiveNxIREEWeb.Telemetry,
12-
LiveNxIREE.Repo,
12+
# LiveNxIREE.Repo,
1313
{DNSCluster, query: Application.get_env(:live_nx_iree, :dns_cluster_query) || :ignore},
1414
{Phoenix.PubSub, name: LiveNxIREE.PubSub},
1515
# Start the Finch HTTP client for sending emails

liveview_native/live_nx_iree/lib/live_nx_iree_web/live/home_live/home_live.ex

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,19 @@ defmodule LiveNxIREEWeb.HomeLive do
44

55
@impl true
66
def mount(_params, _session, socket) do
7+
# register the liveview with presence.
8+
# This will make it so that we can interact with the liveview
9+
# from outside of the liveview itself.
10+
11+
# this can probably be achieved by just changing assigns from within a handle_info call, somehow
12+
13+
# need to figure out how to push and pull events from LVN. otherwise, we'll fallback
14+
# to a long-polling approach from within the liveview itself using a jobqueue via pubsub.
15+
16+
dbg(self())
17+
18+
socket = assign(socket, bytecode: nil, function_signature: nil)
19+
720
{:ok, socket}
821
end
922

@@ -17,4 +30,71 @@ defmodule LiveNxIREEWeb.HomeLive do
1730
|> assign(:page_title, "Listing Contexts")
1831
|> assign(:home, nil)
1932
end
33+
34+
# def handle_info({:nx, :execute, function, input_templates}, socket) do
35+
# socket =
36+
# assign(
37+
# socket,
38+
# :bytecode,
39+
# Base.encode64(inspect({:nx, :execute, function, input_templates}))
40+
# )
41+
42+
# {:noreply, socket}
43+
# end
44+
45+
@impl true
46+
def handle_info({:nx, :execute, function, input_templates, target_device}, socket) do
47+
fun =
48+
case function do
49+
{m, f, a} ->
50+
Function.capture(m, f, a)
51+
52+
_ ->
53+
raise """
54+
Expected a tuple of module, function, and arguments, but got: #{inspect(function)}
55+
"""
56+
end
57+
58+
backend_flag =
59+
case target_device do
60+
:metal -> "--iree-hal-target-backends=metal-spirv"
61+
:cpu -> "--iree-hal-target-backends=llvm-cpu"
62+
end
63+
64+
compiler_flags = [
65+
backend_flag,
66+
"--iree-input-type=stablehlo_xla",
67+
"--iree-execution-model=async-internal"
68+
]
69+
70+
{:ok, %{bytecode: %NxIREE.Module{bytecode: bytecode}, output_container: output_container}} =
71+
NxIREE.Compiler.to_bytecode(fun, input_templates, iree_compiler_flags: compiler_flags)
72+
73+
socket =
74+
socket
75+
|> assign(:bytecode, Base.encode64(bytecode))
76+
|> assign(:output_container, output_container)
77+
|> assign(:function_signature, get_signature(function, input_templates, output_container))
78+
79+
{:noreply, socket}
80+
end
81+
82+
defp get_signature({mod, fun, _a}, input_templates, output_container) do
83+
"#{inspect(mod)}.#{fun}(#{to_flat_type(input_templates)}) -> #{to_flat_type(output_container)}"
84+
end
85+
86+
defp to_flat_type(container) do
87+
List.wrap(container)
88+
|> Nx.Defn.Composite.flatten_list()
89+
|> Enum.map(fn t ->
90+
type_as_string(t) <> "x" <> Enum.join(Tuple.to_list(Nx.shape(t)), "x")
91+
end)
92+
|> Enum.join(", ")
93+
end
94+
95+
defp type_as_string(tensor) do
96+
{t, s} = Nx.type(tensor)
97+
98+
"#{t}#{s}"
99+
end
20100
end
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
<Text>Hello, World</Text>
1+
<NxFunction signature={@function_signature} bytecode={@bytecode}/>

liveview_native/live_nx_iree/native/swiftui/LiveNxIREE/NxAddon/NxFunctionView.swift

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,16 @@
88
import SwiftUI
99
import LiveViewNative
1010

11+
@LiveElement
1112
struct NxFunctionView<Root: RootRegistry>: View {
13+
@_documentation(visibility: public)
14+
@LiveAttribute("bytecode") private var bytecode: String? = nil
15+
@LiveAttribute("signature") private var signature: String? = nil
16+
1217
var body: some View {
13-
Text("NxFunction Component")
14-
.padding()
18+
if signature != nil {
19+
Text(signature!)
20+
.padding()
21+
}
1522
}
1623
}
17-

0 commit comments

Comments
 (0)