Skip to content

Commit 39c9796

Browse files
committed
Fixes and improvmenets
1 parent a7f3444 commit 39c9796

File tree

2 files changed

+58
-74
lines changed

2 files changed

+58
-74
lines changed

build.rs

Lines changed: 57 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@ extern crate semver;
44
use bindgen::callbacks::{ParseCallbacks, TypeKind};
55
use semver::Version;
66
use std::env;
7-
use std::path::{Path, PathBuf};
8-
use std::process::{Command, ExitStatus};
7+
use std::path::PathBuf;
8+
use std::process::Command;
99

1010
const WSL_PACKAGE_NAME: &str = "Microsoft.WSL.PluginApi";
11-
const LOCAL_NUGET_PATH: &str = "nuget_packages"; // Local folder to store NuGet packages
11+
const LOCAL_NUGET_FOLDER: &str = "nuget_packages";
12+
const WSL_PLUGIN_API_BINDGEN_OUTPUT_FILE_NAME: &str = "WSLPluginApi.rs";
1213

1314
#[derive(Debug, Default)]
1415
struct BindgenCallback {
@@ -17,52 +18,52 @@ struct BindgenCallback {
1718

1819
impl BindgenCallback {
1920
fn new(generate_hooks_fields_names: bool) -> Self {
20-
BindgenCallback {
21+
Self {
2122
generate_hooks_fields_name: generate_hooks_fields_names,
2223
}
2324
}
2425
}
2526

2627
impl ParseCallbacks for BindgenCallback {
27-
fn add_derives(&self, _info: &bindgen::callbacks::DeriveInfo<'_>) -> Vec<String> {
28-
if _info.kind == TypeKind::Struct && _info.name == "WSLVersion" {
29-
vec![
30-
"Eq".into(),
31-
"PartialEq".into(),
32-
"Ord".into(),
33-
"PartialOrd".into(),
34-
"Hash".into(),
35-
]
36-
} else if _info.kind == TypeKind::Struct
37-
&& _info.name.contains("PluginHooks")
38-
&& self.generate_hooks_fields_name
39-
{
40-
vec!["FieldNamesAsSlice".into()]
41-
} else {
42-
vec![]
28+
fn add_derives(&self, info: &bindgen::callbacks::DeriveInfo<'_>) -> Vec<String> {
29+
let mut derives = Vec::new();
30+
31+
if info.kind == TypeKind::Struct {
32+
if info.name == "WSLVersion" {
33+
derives.extend(vec![
34+
"Eq".to_string(),
35+
"PartialEq".to_string(),
36+
"Ord".to_string(),
37+
"PartialOrd".to_string(),
38+
"Hash".to_string(),
39+
]);
40+
} else if info.name.contains("PluginHooks") && self.generate_hooks_fields_name {
41+
derives.push("FieldNamesAsSlice".to_string());
42+
}
4343
}
44+
45+
derives
4446
}
4547
}
4648

47-
// Function to ensure the NuGet package is installed in the local folder
49+
/// Ensures that the NuGet package is installed in the local folder.
4850
fn ensure_package_installed(
4951
package_name: &str,
5052
package_version: &str,
51-
output_dir: &str,
52-
) -> Result<ExitStatus, Box<dyn std::error::Error>> {
53-
// Run the NuGet install command with -NonInteractive to avoid prompts
53+
) -> Result<PathBuf, Box<dyn std::error::Error>> {
54+
let out_dir: PathBuf = env::var("OUT_DIR")?.into();
55+
let package_dir = out_dir.join(LOCAL_NUGET_FOLDER);
5456
let status = Command::new("nuget")
5557
.args([
5658
"install",
5759
package_name,
5860
"-Version",
5961
package_version,
6062
"-OutputDirectory",
61-
output_dir, // Local folder to install the NuGet package
62-
"-NonInteractive", // Ensures the command runs without user interaction
63+
package_dir.to_str().unwrap(),
64+
"-NonInteractive",
6365
])
64-
.status()
65-
.expect("Failed to execute nuget install command");
66+
.status()?;
6667

6768
if !status.success() {
6869
return Err(format!(
@@ -71,81 +72,64 @@ fn ensure_package_installed(
7172
)
7273
.into());
7374
}
74-
Ok(status)
75+
Ok(package_dir.join(format!("{}.{}", package_name, package_version)))
7576
}
7677

7778
fn main() -> Result<(), Box<dyn std::error::Error>> {
78-
// Extract the version of the package from the Cargo metadata
79-
let version_str = env!("CARGO_PKG_VERSION");
80-
let version = Version::parse(version_str).expect("Unable to parse the Cargo package version");
81-
let build_metadata = &version.build;
82-
8379
println!("cargo:rerun-if-changed=build.rs");
80+
81+
// Extract version from Cargo package metadata
82+
let version = Version::parse(env!("CARGO_PKG_VERSION"))?;
8483
println!("cargo:version={}", version);
85-
if !build_metadata.is_empty() {
86-
println!("cargo:build-metadata={}", build_metadata);
87-
}
8884

89-
let package_version = build_metadata.to_string();
85+
if !version.build.is_empty() {
86+
println!("cargo:build-metadata={}", version.build);
87+
}
9088

91-
// Ensure the NuGet package is installed in the specified local directory
92-
ensure_package_installed(WSL_PACKAGE_NAME, &package_version, LOCAL_NUGET_PATH)?;
89+
let package_version = version.build.to_string();
90+
let out_path: PathBuf = env::var("OUT_DIR")?.into();
9391

94-
// Construct the full path to the installed package in the local directory
95-
let package_path =
96-
Path::new(LOCAL_NUGET_PATH).join(format!("{:}.{:}", WSL_PACKAGE_NAME, package_version));
92+
// Ensure the NuGet package is installed
93+
let package_path = ensure_package_installed(WSL_PACKAGE_NAME, &package_version)?;
9794

98-
// Construct the path to the header file
99-
let header_file_path = package_path
100-
.join("build")
101-
.join("native")
102-
.join("include")
103-
.join("WslPluginApi.h");
95+
// Construct paths
96+
let header_file_path = package_path.join("build/native/include/WslPluginApi.h");
10497

105-
// Check if the header file exists
10698
if !header_file_path.exists() {
10799
return Err(format!("Header file does not exist: {:?}", header_file_path).into());
108100
}
109101

110102
println!("Using header file from: {:?}", header_file_path);
111103

112-
// Use bindgen to generate Rust bindings from the header file
113104
let hooks_fields_name_feature = env::var("CARGO_FEATURE_HOOKS_FIELD_NAMES").is_ok();
114105
let mut builder = bindgen::Builder::default()
115106
.header(header_file_path.to_str().unwrap())
116107
.raw_line("use windows::core::*;")
117108
.raw_line("use windows::Win32::Foundation::*;")
118109
.raw_line("use windows::Win32::Security::*;")
119110
.raw_line("use windows::Win32::Networking::WinSock::SOCKET;")
120-
.raw_line("#[allow(clippy::upper_case_acronyms)]")
121-
.raw_line("type LPCWSTR = PCWSTR;")
122-
.raw_line("#[allow(clippy::upper_case_acronyms)]")
123-
.raw_line("type LPCSTR = PCSTR;")
124-
.raw_line("#[allow(clippy::upper_case_acronyms)]")
125-
.raw_line("type DWORD = u32;");
126-
127-
if hooks_fields_name_feature {
128-
builder = builder.raw_line("use struct_field_names_as_array::FieldNamesAsSlice;");
129-
}
130-
131-
let api_header = builder
111+
.raw_line("#[allow(clippy::upper_case_acronyms)] type LPCWSTR = PCWSTR;")
112+
.raw_line("#[allow(clippy::upper_case_acronyms)] type LPCSTR = PCSTR;")
113+
.raw_line("#[allow(clippy::upper_case_acronyms)] type DWORD = u32;")
132114
.derive_debug(true)
133115
.derive_copy(true)
134116
.allowlist_item("WSL.*")
135117
.allowlist_item("Wsl.*")
136118
.clang_arg("-fparse-all-comments")
137119
.allowlist_recursively(false)
138120
.parse_callbacks(Box::new(BindgenCallback::new(hooks_fields_name_feature)))
139-
.generate_comments(true)
140-
.generate()
141-
.expect("Unable to generate wslplugins_sys");
142-
143-
// Write the generated bindings to the OUT_DIR
144-
let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
145-
let out_file = out_dir.join("wslplugins_sys.rs");
146-
api_header
147-
.write_to_file(out_file)
148-
.expect("Couldn't write wslplugins_sys!");
121+
.generate_comments(true);
122+
123+
if hooks_fields_name_feature {
124+
builder = builder.raw_line("use struct_field_names_as_array::FieldNamesAsSlice;");
125+
}
126+
127+
// Generate Rust bindings
128+
let api_header = builder.generate()?;
129+
130+
// Write bindings to OUT_DIR
131+
let out_file = out_path.join(WSL_PLUGIN_API_BINDGEN_OUTPUT_FILE_NAME);
132+
api_header.write_to_file(&out_file)?;
149133

150134
Ok(())
151135
}

src/bindgen.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
#![allow(non_upper_case_globals)]
22
#![allow(non_camel_case_types)]
33
#![allow(non_snake_case)]
4-
include!(concat!(env!("OUT_DIR"), "/wslplugins_sys.rs"));
4+
include!(concat!(env!("OUT_DIR"), "/WSLPluginApi.rs"));

0 commit comments

Comments
 (0)