Skip to content

Commit 413d754

Browse files
committed
fix: update iree
1 parent 8d2846e commit 413d754

File tree

15 files changed

+989
-1194
lines changed

15 files changed

+989
-1194
lines changed

.github/workflows/embedded_devices.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ jobs:
1616
MIX_ENV: prod
1717
NX_IREE_PREFER_PRECOMPILED: false
1818
NX_IREE_SOURCE_DIR: ./build-cache/iree
19-
IREE_GIT_REV: candidate-20240604.914
19+
IREE_GIT_REV: candidate-20240822.993
2020
strategy:
2121
fail-fast: true
2222
matrix:

.github/workflows/precompiled_nif.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ jobs:
1616
MIX_ENV: prod
1717
NX_IREE_PREFER_PRECOMPILED: false
1818
NX_IREE_SOURCE_DIR: ./build-cache/iree
19-
IREE_GIT_REV: candidate-20240604.914
19+
IREE_GIT_REV: candidate-20240822.993
2020
BUILD_IREE_RUNTIME: false
2121
strategy:
2222
fail-fast: false
@@ -83,7 +83,7 @@ jobs:
8383
MIX_ENV: prod
8484
NX_IREE_PREFER_PRECOMPILED: false
8585
NX_IREE_SOURCE_DIR: ./build-cache/iree
86-
IREE_GIT_REV: candidate-20240604.914
86+
IREE_GIT_REV: candidate-20240822.993
8787
ImageOS: ubuntu22
8888
LANG: en_US.UTF-8
8989
LANGUAGE: en_US:en

compiler.exs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ Nx.Defn.default_options(compiler: NxIREE.Compiler, iree_compiler_flags: flags, i
1818

1919
f = Nx.Defn.compile(fun, args)
2020

21+
Nx.default_backend(NxIREE.Tensor)
2122
arg0 = Nx.tensor([1.0, 2.0, 3.0, 4.0])
2223
arg1 = Nx.tensor([1, -1, 1, -1])
2324
f.(arg0, arg1) |> dbg()

lib/nx_iree.ex

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,11 @@ defmodule NxIREE do
6969

7070
[driver_name, _] = String.split(device, "://", parts: 2)
7171

72-
input_refs = Enum.map(inputs, &NxIREE.VM.allocate_buffer(&1, device_ref))
72+
input_refs =
73+
Enum.map(inputs, fn
74+
%Nx.Tensor{data: %NxIREE.Tensor{ref: ref}} -> ref
75+
t -> NxIREE.VM.allocate_buffer(t, device_ref)
76+
end)
7377

7478
instance_ref = NxIREE.VM.get_instance()
7579

liveview_native/live_nx_iree/lib/live_nx_iree_web/live/home_live/home_live.ex

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,10 @@ defmodule LiveNxIREEWeb.HomeLive do
6565
{backend_flag, runtime_device} =
6666
case target_device do
6767
:metal ->
68-
{"--iree-hal-target-backends=metal-spirv", "metal://default"}
68+
{"--iree-hal-target-backends=metal-spirv", "metal://"}
6969

7070
:cpu ->
71-
{"--iree-hal-target-backends=llvm-cpu", "local-sync://default"}
71+
{"--iree-hal-target-backends=llvm-cpu", "local-sync://"}
7272
end
7373

7474
compiler_flags = [
@@ -104,7 +104,7 @@ defmodule LiveNxIREEWeb.HomeLive do
104104

105105
{:ok, serialized} = NxIREE.Native.serialize_tensor(tensor.data.ref)
106106

107-
serialized
107+
Base.encode64(serialized)
108108
end)
109109
end
110110

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,7 @@
1-
<NxFunction on-execution="nx-executed" signature={@function_signature} bytecode={@bytecode} device={@device} inputs={@inputs} num-outputs={@num_outputs} />
1+
<NxFunction
2+
on-execution="nx-executed"
3+
signature={@function_signature}
4+
bytecode={@bytecode}
5+
device={@device}
6+
inputs={@inputs}
7+
num-outputs={@num_outputs} />

liveview_native/live_nx_iree/native/swiftui/LiveNxIREE.xcodeproj/project.pbxproj

Lines changed: 920 additions & 1122 deletions
Large diffs are not rendered by default.

liveview_native/live_nx_iree/native/swiftui/LiveNxIREE/LiveNxIREE.swift

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,26 +7,6 @@ import SwiftUI
77

88
@main
99
struct LiveNxIREE: App {
10-
init() {
11-
// Allocate memory for the pointers
12-
let vmInstance = UnsafeMutablePointer<iree_vm_instance_t>.allocate(capacity: 1)
13-
let driverRegistry = UnsafeMutablePointer<iree_hal_driver_registry_t>.allocate(capacity: 1)
14-
let errorMessage = UnsafeMutablePointer<CChar>.allocate(capacity: 256)
15-
16-
// Call the initialization function
17-
let result = nx_iree_initialize(vmInstance, driverRegistry, errorMessage)
18-
19-
if result != 0 {
20-
// Handle the error
21-
let errorString = String(cString: errorMessage)
22-
print("Error initializing nx_iree: \(errorString)")
23-
} else {
24-
globalVmInstance = vmInstance
25-
globalDriverRegistry = driverRegistry
26-
print("nx_iree initialized successfully.")
27-
}
28-
}
29-
3010
var body: some Scene {
3111
WindowGroup {
3212
ContentView()

liveview_native/live_nx_iree/native/swiftui/LiveNxIREE/NxAddon.swift

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,34 +9,24 @@ import LiveViewNative
99
import SwiftUI
1010

1111
class iree_vm_instance_t {}
12-
class iree_hal_driver_registry_t {}
1312
class iree_hal_device_t {}
1413

15-
var globalVmInstance: UnsafeMutablePointer<iree_vm_instance_t>?
16-
var globalDriverRegistry: UnsafeMutablePointer<iree_hal_driver_registry_t>?
17-
18-
@_silgen_name("nx_iree_initialize")
19-
func nx_iree_initialize(
20-
_ vm_instance: UnsafeMutablePointer<iree_vm_instance_t>,
21-
_ driver_registry: UnsafeMutablePointer<iree_hal_driver_registry_t>,
22-
_ error_message: UnsafeMutablePointer<CChar>) -> Int
14+
@_silgen_name("nx_iree_create_instance")
15+
func nx_iree_create_instance() -> UnsafePointer<iree_vm_instance_t>?
2316

2417
@_silgen_name("nx_iree_create_device")
25-
func nx_iree_create_device(
26-
_ driver_registry: UnsafeMutablePointer<iree_hal_driver_registry_t>,
27-
_ name: UnsafePointer<CChar>) -> UnsafeMutablePointer<iree_hal_device_t>
18+
func nx_iree_create_device(_ name: UnsafePointer<CChar>) -> UnsafePointer<iree_hal_device_t>?
2819

2920
@_silgen_name("nx_iree_call")
3021
func nx_iree_call(
31-
_ vm_instance: UnsafeMutablePointer<iree_vm_instance_t>,
32-
_ device: UnsafeMutablePointer<iree_hal_device_t>,
22+
_ vm_instance: UnsafePointer<iree_vm_instance_t>,
23+
_ device: UnsafePointer<iree_hal_device_t>,
3324
_ bytecode_size: UInt64,
3425
_ bytecode: UnsafePointer<CUnsignedChar>,
3526
_ num_inputs: UInt64,
3627
_ serialized_inputs: UnsafePointer<UnsafePointer<CChar>>,
3728
_ num_outputs: UInt64,
38-
_ serialized_outputs: UnsafePointer<UnsafePointer<CChar>>,
39-
_ error_message: UnsafeMutablePointer<CChar>) -> Int
29+
_ error_message: UnsafeMutablePointer<CChar>) -> UnsafePointer<UnsafePointer<CChar>>?
4030

4131

4232

0 commit comments

Comments
 (0)