Skip to content

Commit 0b4fcb4

Browse files
committed
Add support for Kotlin functions
This commit adds support for proper Kotlin functions handling by adapting kotlin.jvm.functions.Function1 to java.util.function.Function and kotlin.jvm.functions.Function2 to java.util.function.BiFunction. It also removes the dependency on Spring Cloud Function and net.jodah:typetools which are replaced by leveraging Spring Framework ResolvableType capabilities.
1 parent 3e9a3fd commit 0b4fcb4

File tree

9 files changed

+527
-115
lines changed

9 files changed

+527
-115
lines changed

spring-ai-core/pom.xml

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -55,18 +55,6 @@
5555
<version>${jsonschema.version}</version>
5656
</dependency>
5757

58-
<dependency>
59-
<groupId>org.springframework.cloud</groupId>
60-
<artifactId>spring-cloud-function-context</artifactId>
61-
<version>${spring-cloud-function-context.version}</version>
62-
<exclusions>
63-
<exclusion>
64-
<groupId>org.springframework.boot</groupId>
65-
<artifactId>spring-boot-autoconfigure</artifactId>
66-
</exclusion>
67-
</exclusions>
68-
</dependency>
69-
7058
<!-- production dependencies -->
7159
<dependency>
7260
<groupId>org.antlr</groupId>
@@ -139,6 +127,13 @@
139127
<version>${jackson.version}</version>
140128
</dependency>
141129

130+
<dependency>
131+
<groupId>org.jetbrains.kotlin</groupId>
132+
<artifactId>kotlin-stdlib</artifactId>
133+
<version>${kotlin.version}</version>
134+
<optional>true</optional>
135+
</dependency>
136+
142137
<!-- test dependencies -->
143138
<dependency>
144139
<groupId>org.springframework.boot</groupId>
@@ -147,16 +142,16 @@
147142
</dependency>
148143

149144
<dependency>
150-
<groupId>org.jetbrains.kotlin</groupId>
151-
<artifactId>kotlin-stdlib</artifactId>
152-
<version>${kotlin.version}</version>
145+
<groupId>com.fasterxml.jackson.module</groupId>
146+
<artifactId>jackson-module-kotlin</artifactId>
147+
<version>${jackson.version}</version>
153148
<scope>test</scope>
154149
</dependency>
155150

156151
<dependency>
157-
<groupId>com.fasterxml.jackson.module</groupId>
158-
<artifactId>jackson-module-kotlin</artifactId>
159-
<version>${jackson.version}</version>
152+
<groupId>io.mockk</groupId>
153+
<artifactId>mockk-jvm</artifactId>
154+
<version>1.13.13</version>
160155
<scope>test</scope>
161156
</dependency>
162157

spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackContext.java

Lines changed: 58 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,22 @@
1616

1717
package org.springframework.ai.model.function;
1818

19-
import java.lang.reflect.Type;
2019
import java.util.function.BiFunction;
2120
import java.util.function.Function;
2221

2322
import com.fasterxml.jackson.annotation.JsonClassDescription;
23+
import kotlin.jvm.functions.Function1;
24+
import kotlin.jvm.functions.Function2;
2425

2526
import org.springframework.ai.chat.model.ToolContext;
2627
import org.springframework.beans.BeansException;
27-
import org.springframework.cloud.function.context.catalog.FunctionTypeUtils;
28-
import org.springframework.cloud.function.context.config.FunctionContextUtils;
28+
import org.springframework.beans.factory.NoSuchBeanDefinitionException;
29+
import org.springframework.beans.factory.config.BeanDefinition;
2930
import org.springframework.context.ApplicationContext;
3031
import org.springframework.context.ApplicationContextAware;
3132
import org.springframework.context.annotation.Description;
3233
import org.springframework.context.support.GenericApplicationContext;
34+
import org.springframework.core.ResolvableType;
3335
import org.springframework.lang.NonNull;
3436
import org.springframework.lang.Nullable;
3537
import org.springframework.util.StringUtils;
@@ -49,6 +51,7 @@
4951
*
5052
* @author Christian Tzolov
5153
* @author Christopher Smith
54+
* @author Sebastien Deleuze
5255
*/
5356
public class FunctionCallbackContext implements ApplicationContextAware {
5457

@@ -68,23 +71,19 @@ public void setApplicationContext(@NonNull ApplicationContext applicationContext
6871
@SuppressWarnings({ "rawtypes", "unchecked" })
6972
public FunctionCallback getFunctionCallback(@NonNull String beanName, @Nullable String defaultDescription) {
7073

71-
Type beanType = FunctionContextUtils.findType(this.applicationContext.getBeanFactory(), beanName);
72-
73-
if (beanType == null) {
74-
throw new IllegalArgumentException(
75-
"Functional bean with name: " + beanName + " does not exist in the context.");
74+
BeanDefinition beanDefinition;
75+
try {
76+
beanDefinition = this.applicationContext.getBeanDefinition(beanName);
7677
}
77-
78-
if (!Function.class.isAssignableFrom(FunctionTypeUtils.getRawType(beanType))
79-
&& !BiFunction.class.isAssignableFrom(FunctionTypeUtils.getRawType(beanType))) {
78+
catch (NoSuchBeanDefinitionException ex) {
8079
throw new IllegalArgumentException(
81-
"Function call Bean must be of type Function or BiFunction. Found: " + beanType.getTypeName());
80+
"Functional bean with name " + beanName + " does not exist in the context.");
8281
}
8382

84-
Type functionInputType = TypeResolverHelper.getFunctionArgumentType(beanType, 0);
83+
ResolvableType functionType = beanDefinition.getResolvableType();
84+
ResolvableType functionInputType = TypeResolverHelper.getFunctionArgumentType(functionType.getType(), 0);
8585

86-
Class<?> functionInputClass = FunctionTypeUtils.getRawType(functionInputType);
87-
String functionName = beanName;
86+
Class<?> functionInputClass = functionInputType.toClass();
8887
String functionDescription = defaultDescription;
8988

9089
if (!StringUtils.hasText(functionDescription)) {
@@ -114,24 +113,40 @@ public FunctionCallback getFunctionCallback(@NonNull String beanName, @Nullable
114113

115114
Object bean = this.applicationContext.getBean(beanName);
116115

117-
if (bean instanceof Function<?, ?> function) {
116+
if (KotlinDelegate.isKotlinFunction(functionType.toClass())) {
117+
return FunctionCallbackWrapper.builder(KotlinDelegate.wrapKotlinFunction(bean))
118+
.withName(beanName)
119+
.withSchemaType(this.schemaType)
120+
.withDescription(functionDescription)
121+
.withInputType(functionInputClass)
122+
.build();
123+
}
124+
else if (KotlinDelegate.isKotlinBiFunction(functionType.toClass())) {
125+
return FunctionCallbackWrapper.builder(KotlinDelegate.wrapKotlinBiFunction(bean))
126+
.withName(beanName)
127+
.withSchemaType(this.schemaType)
128+
.withDescription(functionDescription)
129+
.withInputType(functionInputClass)
130+
.build();
131+
}
132+
else if (bean instanceof Function<?, ?> function) {
118133
return FunctionCallbackWrapper.builder(function)
119-
.withName(functionName)
134+
.withName(beanName)
120135
.withSchemaType(this.schemaType)
121136
.withDescription(functionDescription)
122137
.withInputType(functionInputClass)
123138
.build();
124139
}
125-
else if (bean instanceof BiFunction<?, ?, ?> biFunction) {
126-
return FunctionCallbackWrapper.builder((BiFunction<?, ToolContext, ?>) biFunction)
127-
.withName(functionName)
140+
else if (bean instanceof BiFunction<?, ?, ?>) {
141+
return FunctionCallbackWrapper.builder((BiFunction<?, ToolContext, ?>) bean)
142+
.withName(beanName)
128143
.withSchemaType(this.schemaType)
129144
.withDescription(functionDescription)
130145
.withInputType(functionInputClass)
131146
.build();
132147
}
133148
else {
134-
throw new IllegalArgumentException("Bean must be of type Function");
149+
throw new IllegalStateException();
135150
}
136151
}
137152

@@ -141,4 +156,26 @@ public enum SchemaType {
141156

142157
}
143158

159+
private static class KotlinDelegate {
160+
161+
public static boolean isKotlinFunction(Class<?> clazz) {
162+
return Function1.class.isAssignableFrom(clazz);
163+
}
164+
165+
@SuppressWarnings("unchecked")
166+
public static Function<?, ?> wrapKotlinFunction(Object function) {
167+
return t -> ((Function1<Object, Object>) function).invoke(t);
168+
}
169+
170+
public static boolean isKotlinBiFunction(Class<?> clazz) {
171+
return Function2.class.isAssignableFrom(clazz);
172+
}
173+
174+
@SuppressWarnings("unchecked")
175+
public static BiFunction<?, ToolContext, ?> wrapKotlinBiFunction(Object function) {
176+
return (t, u) -> ((Function2<Object, ToolContext, Object>) function).invoke(t, u);
177+
}
178+
179+
}
180+
144181
}

spring-ai-core/src/main/java/org/springframework/ai/model/function/TypeResolverHelper.java

Lines changed: 52 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,22 @@
1616

1717
package org.springframework.ai.model.function;
1818

19-
import java.lang.reflect.GenericArrayType;
20-
import java.lang.reflect.ParameterizedType;
2119
import java.lang.reflect.Type;
2220
import java.util.function.BiFunction;
2321
import java.util.function.Function;
2422

25-
import net.jodah.typetools.TypeResolver;
23+
import kotlin.jvm.functions.Function1;
24+
import kotlin.jvm.functions.Function2;
2625

27-
import org.springframework.cloud.function.context.catalog.FunctionTypeUtils;
26+
import org.springframework.core.KotlinDetector;
27+
import org.springframework.core.ResolvableType;
2828

2929
/**
3030
* A utility class that provides methods for resolving types and classes related to
3131
* functions.
3232
*
3333
* @author Christian Tzolov
34+
* @author Sebastien Dekeuze
3435
*/
3536
public abstract class TypeResolverHelper {
3637

@@ -68,12 +69,9 @@ public static Class<?> getFunctionOutputClass(Class<? extends Function<?, ?>> fu
6869
* @return The class of the specified function argument.
6970
*/
7071
public static Class<?> getFunctionArgumentClass(Class<? extends Function<?, ?>> functionClass, int argumentIndex) {
71-
Type type = TypeResolver.reify(Function.class, functionClass);
72-
73-
var argumentType = type instanceof ParameterizedType
74-
? ((ParameterizedType) type).getActualTypeArguments()[argumentIndex] : Object.class;
75-
76-
return toRawClass(argumentType);
72+
ResolvableType resolvableType = ResolvableType.forClass(functionClass).as(Function.class);
73+
return (resolvableType == ResolvableType.NONE ? Object.class
74+
: resolvableType.getGeneric(argumentIndex).toClass());
7775
}
7876

7977
/**
@@ -84,80 +82,65 @@ public static Class<?> getFunctionArgumentClass(Class<? extends Function<?, ?>>
8482
*/
8583
public static Class<?> getBiFunctionArgumentClass(Class<? extends BiFunction<?, ?, ?>> biFunctionClass,
8684
int argumentIndex) {
87-
Type type = TypeResolver.reify(BiFunction.class, biFunctionClass);
88-
89-
Type argumentType = type instanceof ParameterizedType
90-
? ((ParameterizedType) type).getActualTypeArguments()[argumentIndex] : Object.class;
91-
92-
return toRawClass(argumentType);
93-
}
94-
95-
/**
96-
* Returns the input type of a given function class.
97-
* @param functionClass The class of the function.
98-
* @return The input type of the function.
99-
*/
100-
public static Type getFunctionInputType(Class<? extends Function<?, ?>> functionClass) {
101-
return getFunctionArgumentType(functionClass, 0);
102-
}
103-
104-
/**
105-
* Retrieves the output type of a given function class.
106-
* @param functionClass The function class.
107-
* @return The output type of the function.
108-
*/
109-
public static Type getFunctionOutputType(Class<? extends Function<?, ?>> functionClass) {
110-
return getFunctionArgumentType(functionClass, 1);
85+
ResolvableType resolvableType = ResolvableType.forClass(biFunctionClass).as(BiFunction.class);
86+
return (resolvableType == ResolvableType.NONE ? Object.class
87+
: resolvableType.getGeneric(argumentIndex).toClass());
11188
}
11289

11390
/**
11491
* Retrieves the type of a specific argument in a given function class.
115-
* @param functionClass The function class.
116-
* @param argumentIndex The index of the argument whose type should be retrieved.
117-
* @return The type of the specified function argument.
118-
*/
119-
public static Type getFunctionArgumentType(Class<? extends Function<?, ?>> functionClass, int argumentIndex) {
120-
Type functionType = TypeResolver.reify(Function.class, functionClass);
121-
return getFunctionArgumentType(functionType, argumentIndex);
122-
}
123-
124-
/**
125-
* Retrieves the type of a specific argument in a given function type.
12692
* @param functionType The function type.
12793
* @param argumentIndex The index of the argument whose type should be retrieved.
12894
* @return The type of the specified function argument.
95+
* @throws IllegalArgumentException if functionType is not a supported type
12996
*/
130-
public static Type getFunctionArgumentType(Type functionType, int argumentIndex) {
131-
132-
// Resolves: https://github.com/spring-projects/spring-ai/issues/726
133-
if (!(functionType instanceof ParameterizedType)) {
134-
Class<?> functionalClass = FunctionTypeUtils.getRawType(functionType);
135-
// Resolves: https://github.com/spring-projects/spring-ai/issues/1576
136-
if (BiFunction.class.isAssignableFrom(functionalClass)) {
137-
functionType = TypeResolver.reify(BiFunction.class, (Class<BiFunction<?, ?, ?>>) functionalClass);
97+
public static ResolvableType getFunctionArgumentType(Type functionType, int argumentIndex) {
98+
99+
ResolvableType resolvableType = ResolvableType.forType(functionType);
100+
Class<?> resolvableClass = resolvableType.toClass();
101+
ResolvableType functionArgumentResolvableType = ResolvableType.NONE;
102+
103+
if (Function.class.isAssignableFrom(resolvableClass)) {
104+
functionArgumentResolvableType = resolvableType.as(Function.class);
105+
}
106+
else if (BiFunction.class.isAssignableFrom(resolvableClass)) {
107+
functionArgumentResolvableType = resolvableType.as(BiFunction.class);
108+
}
109+
else if (KotlinDetector.isKotlinPresent()) {
110+
if (KotlinDelegate.isKotlinFunction(resolvableClass)) {
111+
functionArgumentResolvableType = KotlinDelegate.adaptToKotlinFunctionType(resolvableType);
138112
}
139-
else {
140-
functionType = FunctionTypeUtils.discoverFunctionTypeFromClass(functionalClass);
113+
else if (KotlinDelegate.isKotlinBiFunction(resolvableClass)) {
114+
functionArgumentResolvableType = KotlinDelegate.adaptToKotlinBiFunctionType(resolvableType);
141115
}
142116
}
143117

144-
var argumentType = functionType instanceof ParameterizedType
145-
? ((ParameterizedType) functionType).getActualTypeArguments()[argumentIndex] : Object.class;
118+
if (functionArgumentResolvableType == ResolvableType.NONE) {
119+
throw new IllegalArgumentException(
120+
"Type must be a Function, BiFunction, Function1 or Function2. Found: " + resolvableType);
121+
}
146122

147-
return argumentType;
123+
return functionArgumentResolvableType.getGeneric(argumentIndex);
148124
}
149125

150-
/**
151-
* Effectively converts {@link Type} which could be {@link ParameterizedType} to raw
152-
* Class (no generics).
153-
* @param type actual {@link Type} instance
154-
* @return instance of {@link Class} as raw representation of the provided
155-
* {@link Type}
156-
*/
157-
public static Class<?> toRawClass(Type type) {
158-
return type != null
159-
? TypeResolver.resolveRawClass(type instanceof GenericArrayType ? type : TypeResolver.reify(type), null)
160-
: null;
126+
private static class KotlinDelegate {
127+
128+
public static boolean isKotlinFunction(Class<?> clazz) {
129+
return Function1.class.isAssignableFrom(clazz);
130+
}
131+
132+
public static ResolvableType adaptToKotlinFunctionType(ResolvableType resolvableType) {
133+
return resolvableType.as(Function1.class);
134+
}
135+
136+
public static boolean isKotlinBiFunction(Class<?> clazz) {
137+
return Function2.class.isAssignableFrom(clazz);
138+
}
139+
140+
public static ResolvableType adaptToKotlinBiFunctionType(ResolvableType resolvableType) {
141+
return resolvableType.as(Function2.class);
142+
}
143+
161144
}
162145

163146
}

spring-ai-core/src/test/java/org/springframework/ai/model/function/TypeResolverHelperIT.java

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@
2323
import org.junit.jupiter.params.provider.ValueSource;
2424

2525
import org.springframework.beans.factory.annotation.Autowired;
26+
import org.springframework.beans.factory.config.BeanDefinition;
2627
import org.springframework.boot.SpringBootConfiguration;
2728
import org.springframework.boot.test.context.SpringBootTest;
28-
import org.springframework.cloud.function.context.config.FunctionContextUtils;
2929
import org.springframework.context.annotation.Bean;
3030
import org.springframework.context.support.GenericApplicationContext;
3131

@@ -41,12 +41,11 @@ class TypeResolverHelperIT {
4141
@ValueSource(strings = { "weatherClassDefinition", "weatherFunctionDefinition", "standaloneWeatherFunction" })
4242
void beanInputTypeResolutionTest(String beanName) {
4343
assertThat(this.applicationContext).isNotNull();
44-
Type beanType = FunctionContextUtils.findType(this.applicationContext.getBeanFactory(), beanName);
45-
assertThat(beanType).isNotNull();
46-
Type functionInputType = TypeResolverHelper.getFunctionArgumentType(beanType, 0);
47-
assertThat(functionInputType).isNotNull();
48-
assertThat(functionInputType.getTypeName()).isEqualTo(WeatherRequest.class.getName());
49-
44+
BeanDefinition beanDefinition = this.applicationContext.getBeanDefinition(beanName);
45+
Type beanType = beanDefinition.getResolvableType().getType();
46+
Class<?> functionInputClass = TypeResolverHelper.getFunctionArgumentType(beanType, 0).getRawClass();
47+
assertThat(functionInputClass).isNotNull();
48+
assertThat(functionInputClass.getTypeName()).isEqualTo(WeatherRequest.class.getName());
5049
}
5150

5251
public record WeatherRequest(String city) {

0 commit comments

Comments
 (0)