Skip to content

Commit 712bcc5

Browse files
Merge branch 'main' into scan_support
2 parents 843b3d2 + 9eaea4a commit 712bcc5

File tree

409 files changed

+17802
-3286
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

409 files changed

+17802
-3286
lines changed

.github/workflows/trunk.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,7 @@ jobs:
381381
elif [[ ${{ matrix.os}} == "zephyr-preset" ]]; then
382382
setup_script_args="--target-toolchain zephyr"
383383
toolchain_prefix=arm-zephyr-eabi-
384-
threshold="135768" # 136 KiB
384+
threshold="136000" # 136 KiB
385385
toolchain_cmake=examples/zephyr/x86_64-linux-arm-zephyr-eabi-gcc.cmake
386386
else
387387
echo "Fail unsupport OS selection ${{ matrix.os }}"

backends/aoti/aoti_backend.py

Lines changed: 17 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import typing
1010
from abc import ABC, abstractmethod
1111
from enum import Enum
12-
from typing import Any, Dict, List, Optional, Set
12+
from typing import Any, Dict, List, Set
1313

1414
import torch
1515
from executorch.backends.aoti.passes.replace_view_copy_with_view import (
@@ -70,10 +70,15 @@ def get_aoti_compile_options(
7070

7171
@classmethod
7272
@abstractmethod
73-
def get_custom_passes(cls) -> List[typing.Any]:
73+
def get_custom_passes(cls, compile_specs: List[CompileSpec]) -> List[typing.Any]:
7474
"""Return the list of custom passes to apply after ReplaceViewCopyWithViewPass and before decomposition."""
7575
pass
7676

77+
@classmethod
78+
def get_extra_aoti_compile_context_manager(cls):
79+
"""Return extra context manager to apply during aoti_compile stage. By default returns an empty context manager."""
80+
return contextlib.nullcontext()
81+
7782
@classmethod
7883
@contextlib.contextmanager
7984
def collect_unsupported_fallback_kernels(cls, missing_fallback_kernels: Set[str]):
@@ -91,39 +96,24 @@ def collect_unsupported_fallback_kernels(cls, missing_fallback_kernels: Set[str]
9196
)
9297

9398
def generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels(
94-
self,
95-
kernel: str,
96-
args: list[str],
97-
device: str,
98-
*,
99-
debug_args: Optional[list[str]] = None,
100-
debug_handle: Optional[int] = None,
101-
):
99+
self, kernel: str, *args: Any, **kwargs: Any
100+
) -> None:
102101
if kernel not in supported_kernels:
103102
missing_fallback_kernels.add(kernel)
104103

105-
original_generate_c_shim_extern_kernel_call(
106-
self,
107-
kernel,
108-
args,
109-
device,
110-
debug_args=debug_args,
111-
debug_handle=debug_handle,
104+
return original_generate_c_shim_extern_kernel_call(
105+
self, kernel, *args, **kwargs
112106
)
113107

114108
def generate_fallback_kernel_with_runtime_lookup_aot_and_collect_unsupported_kernels(
115-
self,
116-
op_overload,
117-
raw_args,
118-
output_args,
119-
raw_outputs,
120-
):
109+
self, op_overload: Any, *args: Any, **kwargs: Any
110+
) -> None:
121111
kernel_name = getattr(op_overload, "_name", str(op_overload))
122112
if kernel_name not in supported_kernels:
123113
missing_fallback_kernels.add(kernel_name)
124114

125-
original_generate_fallback_kernel_with_runtime_lookup_aot(
126-
self, op_overload, raw_args, output_args, raw_outputs
115+
return original_generate_fallback_kernel_with_runtime_lookup_aot(
116+
self, op_overload, *args, **kwargs
127117
)
128118

129119
CppWrapperCpu.generate_c_shim_extern_kernel_call = (
@@ -164,7 +154,7 @@ def preprocess(
164154
ReplaceViewCopyWithViewPass()(device_edge_program.graph_module)
165155

166156
# Apply custom backend-specific passes
167-
custom_passes = cls.get_custom_passes()
157+
custom_passes = cls.get_custom_passes(compile_specs)
168158
for custom_pass in custom_passes:
169159
custom_pass(device_edge_program.graph_module)
170160

@@ -189,7 +179,7 @@ def preprocess(
189179
# Compile with fallback kernel collection
190180
with cls.collect_unsupported_fallback_kernels(
191181
missing_fallback_kernels
192-
), torch.no_grad():
182+
), torch.no_grad(), cls.get_extra_aoti_compile_context_manager():
193183
paths = torch._inductor.aot_compile(
194184
edge_program_module, tuple(user_input_placeholders), options=options
195185
)

backends/apple/coreml/runtime/delegate/ETCoreMLStrings.mm

Lines changed: 38 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -101,39 +101,50 @@ + (NSString *)debugSymbolToHandlesKeyName {
101101
}
102102

103103
+ (nullable NSString *)assetsDirectoryPath {
104-
static dispatch_once_t onceToken;
105-
static NSString *result = nil;
106-
dispatch_once(&onceToken, ^{
107-
NSArray<NSString *> *paths = NSSearchPathForDirectoriesInDomains(NSCachesDirectory, NSUserDomainMask, YES);
108-
if (paths.count > 0) {
109-
result = [paths.lastObject stringByAppendingPathComponent:self.productName];
110-
}
111-
});
112-
113-
return result;
104+
#if defined(EXECUTORCH_COREML_ASSETS_DIRECTORY_PATH)
105+
return @(EXECUTORCH_COREML_ASSETS_DIRECTORY_PATH);
106+
#else
107+
static dispatch_once_t onceToken;
108+
static NSString *result = nil;
109+
dispatch_once(&onceToken, ^{
110+
NSArray<NSString *> *paths = NSSearchPathForDirectoriesInDomains(NSCachesDirectory, NSUserDomainMask, YES);
111+
if (paths.count > 0) {
112+
result = [paths.lastObject stringByAppendingPathComponent:self.productName];
113+
}
114+
});
115+
116+
return result;
117+
#endif
114118
}
115119

116120
+ (nullable NSString *)trashDirectoryPath {
117-
static dispatch_once_t onceToken;
118-
static NSString *result = nil;
119-
dispatch_once(&onceToken, ^{
120-
result = [NSTemporaryDirectory() stringByAppendingPathComponent:self.productName];
121-
});
122-
123-
return result;
121+
#if defined(EXECUTORCH_COREML_TRASH_DIRECTORY_PATH)
122+
return @(EXECUTORCH_COREML_TRASH_DIRECTORY_PATH);
123+
#else
124+
static dispatch_once_t onceToken;
125+
static NSString *result = nil;
126+
dispatch_once(&onceToken, ^{
127+
result = [NSTemporaryDirectory() stringByAppendingPathComponent:self.productName];
128+
});
129+
130+
return result;
131+
#endif
124132
}
125133

126134
+ (nullable NSString *)databaseDirectoryPath {
127-
static dispatch_once_t onceToken;
128-
static NSString *result = nil;
129-
dispatch_once(&onceToken, ^{
130-
NSArray<NSString *> *paths = NSSearchPathForDirectoriesInDomains(NSApplicationSupportDirectory, NSUserDomainMask, YES);
131-
if (paths.count > 0) {
132-
result = [paths.lastObject stringByAppendingPathComponent:self.productName];
133-
}
134-
});
135-
136-
return result;
135+
#if defined(EXECUTORCH_COREML_DATABASE_DIRECTORY_PATH)
136+
return @(EXECUTORCH_COREML_DATABASE_DIRECTORY_PATH);
137+
#else
138+
static dispatch_once_t onceToken;
139+
static NSString *result = nil;
140+
dispatch_once(&onceToken, ^{
141+
NSArray<NSString *> *paths = NSSearchPathForDirectoriesInDomains(NSApplicationSupportDirectory, NSUserDomainMask, YES);
142+
if (paths.count > 0) {
143+
result = [paths.lastObject stringByAppendingPathComponent:self.productName];
144+
}
145+
});
146+
return result;
147+
#endif
137148
}
138149

139150

backends/apple/metal/metal_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def get_decomposition_table(cls) -> Dict[Any, Any]:
4242
return {}
4343

4444
@classmethod
45-
def get_custom_passes(cls) -> List[typing.Any]:
45+
def get_custom_passes(cls, compile_specs: List[CompileSpec]) -> List[typing.Any]:
4646
"""Return Metal-specific passes (currently none)"""
4747
return []
4848

backends/apple/metal/runtime/shims/et_metal.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,13 @@ class ETMetalKernelFunction {
181181
void startEncoding();
182182
void setArg(unsigned idx, const executorch::runtime::etensor::Tensor& tensor);
183183
void setArg(unsigned idx, int64_t val);
184+
void setArg(unsigned idx, uint32_t val);
185+
void setArg(unsigned idx, float val);
186+
void setArg(unsigned idx, bool val);
187+
void setArg(unsigned idx, const void* data, size_t size);
188+
189+
// Helper for Metal uint3 struct
190+
void setArgUint3(unsigned idx, uint32_t x, uint32_t y, uint32_t z);
184191

185192
void dispatchSingle(uint64_t length);
186193
void dispatchSingleWithGroupSize(uint64_t length, uint64_t group_size);
@@ -191,6 +198,15 @@ class ETMetalKernelFunction {
191198
const uint64_t* group_size,
192199
size_t group_size_size);
193200

201+
// Dispatch with explicit threadgroup count (not thread count)
202+
void dispatchThreadgroups(
203+
uint64_t gridX,
204+
uint64_t gridY,
205+
uint64_t gridZ,
206+
uint64_t threadsX,
207+
uint64_t threadsY,
208+
uint64_t threadsZ);
209+
194210
void runCommandBlock(std::function<void(void)> f);
195211

196212
private:

backends/apple/metal/runtime/shims/et_metal.mm

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#import <MetalPerformanceShaders/MetalPerformanceShaders.h>
1111
#import <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>
1212
#import <Foundation/Foundation.h>
13+
#include <simd/simd.h>
1314
#include <executorch/runtime/platform/log.h>
1415
#include <executorch/runtime/core/exec_aten/exec_aten.h>
1516
#include <executorch/backends/apple/metal/runtime/shims/et_metal.h>
@@ -377,6 +378,58 @@ int metal_copy_memory(void* dst, const void* src, size_t nbytes, bool src_is_dev
377378
ET_LOG(Debug, "ETMetalKernelFunction::setArg: Set int64_t value %lld at index %u", val, idx);
378379
}
379380

381+
void ETMetalKernelFunction::setArg(unsigned idx, uint32_t val) {
382+
if (!encoder_) {
383+
ET_LOG(Error, "ETMetalKernelFunction::setArg: No active encoder");
384+
return;
385+
}
386+
387+
[encoder_ setBytes:&val length:sizeof(uint32_t) atIndex:idx];
388+
ET_LOG(Debug, "ETMetalKernelFunction::setArg: Set uint32_t value %u at index %u", val, idx);
389+
}
390+
391+
void ETMetalKernelFunction::setArg(unsigned idx, float val) {
392+
if (!encoder_) {
393+
ET_LOG(Error, "ETMetalKernelFunction::setArg: No active encoder");
394+
return;
395+
}
396+
397+
[encoder_ setBytes:&val length:sizeof(float) atIndex:idx];
398+
ET_LOG(Debug, "ETMetalKernelFunction::setArg: Set float value %f at index %u", val, idx);
399+
}
400+
401+
void ETMetalKernelFunction::setArg(unsigned idx, bool val) {
402+
if (!encoder_) {
403+
ET_LOG(Error, "ETMetalKernelFunction::setArg: No active encoder");
404+
return;
405+
}
406+
407+
[encoder_ setBytes:&val length:sizeof(bool) atIndex:idx];
408+
ET_LOG(Debug, "ETMetalKernelFunction::setArg: Set bool value %s at index %u", val ? "true" : "false", idx);
409+
}
410+
411+
void ETMetalKernelFunction::setArg(unsigned idx, const void* data, size_t size) {
412+
if (!encoder_) {
413+
ET_LOG(Error, "ETMetalKernelFunction::setArg: No active encoder");
414+
return;
415+
}
416+
417+
[encoder_ setBytes:data length:size atIndex:idx];
418+
ET_LOG(Debug, "ETMetalKernelFunction::setArg: Set bytes at index %u (size: %zu)", idx, size);
419+
}
420+
421+
void ETMetalKernelFunction::setArgUint3(unsigned idx, uint32_t x, uint32_t y, uint32_t z) {
422+
if (!encoder_) {
423+
ET_LOG(Error, "ETMetalKernelFunction::setArgUint3: No active encoder");
424+
return;
425+
}
426+
427+
// Use SIMD library's uint3 type which matches Metal shader's uint3 layout
428+
simd_uint3 val = {x, y, z};
429+
[encoder_ setBytes:&val length:sizeof(simd_uint3) atIndex:idx];
430+
ET_LOG(Debug, "ETMetalKernelFunction::setArgUint3: Set uint3{%u, %u, %u} at index %u", x, y, z, idx);
431+
}
432+
380433
void ETMetalKernelFunction::dispatchSingle(uint64_t length) {
381434
if (!encoder_) {
382435
ET_LOG(Error, "ETMetalKernelFunction::dispatchSingle: No active encoder");
@@ -502,6 +555,40 @@ int metal_copy_memory(void* dst, const void* src, size_t nbytes, bool src_is_dev
502555

503556
}
504557

558+
void ETMetalKernelFunction::dispatchThreadgroups(uint64_t gridX, uint64_t gridY, uint64_t gridZ,
559+
uint64_t threadsX, uint64_t threadsY, uint64_t threadsZ) {
560+
if (!encoder_) {
561+
ET_LOG(Error, "ETMetalKernelFunction::dispatchThreadgroups: No active encoder");
562+
return;
563+
}
564+
565+
if (!cps_) {
566+
ET_LOG(Error, "ETMetalKernelFunction::dispatchThreadgroups: No compute pipeline state");
567+
return;
568+
}
569+
570+
// Calculate total threads per threadgroup
571+
uint64_t totalThreads = threadsX * threadsY * threadsZ;
572+
573+
const auto maxThreadsPerGroup = static_cast<uint64_t>([cps_ maxTotalThreadsPerThreadgroup]);
574+
575+
// Validate total thread count
576+
if (totalThreads > maxThreadsPerGroup) {
577+
ET_LOG(Error, "ETMetalKernelFunction::dispatchThreadgroups: Requested %llu total threads per threadgroup exceeds device maximum of %llu",
578+
(unsigned long long)totalThreads, (unsigned long long)maxThreadsPerGroup);
579+
return;
580+
}
581+
582+
MTLSize threadgroupsPerGrid = MTLSizeMake(gridX, gridY, gridZ);
583+
MTLSize threadsPerThreadgroup = MTLSizeMake(threadsX, threadsY, threadsZ);
584+
585+
[encoder_ dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadsPerThreadgroup];
586+
587+
ET_LOG(Debug, "ETMetalKernelFunction::dispatchThreadgroups: Dispatched grid [%llu, %llu, %llu] with threadgroup [%llu, %llu, %llu]",
588+
(unsigned long long)gridX, (unsigned long long)gridY, (unsigned long long)gridZ,
589+
(unsigned long long)threadsX, (unsigned long long)threadsY, (unsigned long long)threadsZ);
590+
}
591+
505592
void ETMetalKernelFunction::runCommandBlock(std::function<void(void)> f) {
506593
// Use dispatch_sync with the stream's serial queue for thread safety and synchronization
507594
// This matches PyTorch's approach: dispatch_sync_with_rethrow(getCurrentMPSStream()->queue(), ...)

0 commit comments

Comments
 (0)