Skip to content

Commit 9eff321

Browse files
committed
Merge branch 'release/2.0' into develop
2 parents 0d16e49 + 22454f5 commit 9eff321

File tree

1 file changed

+22
-5
lines changed

1 file changed

+22
-5
lines changed

src/main/java/co/cask/hydrator/plugin/spark/dynamic/ScalaSparkCompute.java

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ public void configurePipeline(PipelineConfigurer pipelineConfigurer) throws Ille
118118
Method method = getTransformMethod(interpreter.getClassLoader(), className);
119119

120120
// If the method takes DataFrame, make sure it has input schema
121-
if (method.getParameterTypes()[0].equals(DATAFRAME_TYPE) && stageConfigurer.getInputSchema() == null) {
121+
if (isDataFrame(method.getParameterTypes()[0]) && stageConfigurer.getInputSchema() == null) {
122122
throw new IllegalArgumentException("Missing input schema for transformation using DataFrame");
123123
}
124124

@@ -172,7 +172,7 @@ public void onEvent(SparkListenerEvent event) {
172172
interpreter.compile(generateSourceClass(className));
173173
method = getTransformMethod(interpreter.getClassLoader(), className);
174174

175-
isDataFrame = method.getParameterTypes()[0].equals(DATAFRAME_TYPE);
175+
isDataFrame = isDataFrame(method.getParameterTypes()[0]);
176176
takeContext = method.getParameterTypes().length == 2;
177177

178178
// Input schema shouldn't be null
@@ -302,9 +302,10 @@ private Method getTransformMethod(ClassLoader classLoader, String className) {
302302
}
303303

304304
Type[] parameterTypes = method.getGenericParameterTypes();
305+
boolean isDataFrame = isDataFrame(parameterTypes[0]);
305306

306307
// The first parameter should be of type RDD[StructuredRecord] if it takes RDD
307-
if (!parameterTypes[0].equals(DATAFRAME_TYPE)) {
308+
if (!isDataFrame) {
308309
validateRDDType(parameterTypes[0],
309310
"The first parameter of the 'transform' method should have type as 'RDD[StructuredRecord]'");
310311
}
@@ -317,8 +318,8 @@ private Method getTransformMethod(ClassLoader classLoader, String className) {
317318

318319
// The return type of the method must be RDD[StructuredRecord] if it takes RDD
319320
// Or it must be DataFrame if it takes DataFrame
320-
if (parameterTypes[0].equals(DATAFRAME_TYPE)) {
321-
if (!method.getReturnType().equals(DATAFRAME_TYPE)) {
321+
if (isDataFrame) {
322+
if (!isDataFrame(method.getGenericReturnType())) {
322323
throw new IllegalArgumentException("The return type of the 'transform' method should be 'DataFrame'");
323324
}
324325
} else {
@@ -356,6 +357,22 @@ private String generateClassName(String stageName) {
356357
return nameBuilder.toString();
357358
}
358359

360+
/**
361+
* Returns whether the given {@link Type} is a DataFrame type.
362+
*/
363+
private boolean isDataFrame(Type type) {
364+
if (DATAFRAME_TYPE == null) {
365+
return false;
366+
}
367+
if (type instanceof Class) {
368+
return ((Class<?>) type).isAssignableFrom(DATAFRAME_TYPE);
369+
}
370+
if (type instanceof ParameterizedType) {
371+
return isDataFrame(((ParameterizedType) type).getRawType());
372+
}
373+
return false;
374+
}
375+
359376
/**
360377
* Configuration object for the plugin
361378
*/

0 commit comments

Comments
 (0)