@@ -118,7 +118,7 @@ public void configurePipeline(PipelineConfigurer pipelineConfigurer) throws Ille
118118 Method method = getTransformMethod (interpreter .getClassLoader (), className );
119119
120120 // If the method takes DataFrame, make sure it has input schema
121- if (method .getParameterTypes ()[0 ]. equals ( DATAFRAME_TYPE ) && stageConfigurer .getInputSchema () == null ) {
121+ if (isDataFrame ( method .getParameterTypes ()[0 ]) && stageConfigurer .getInputSchema () == null ) {
122122 throw new IllegalArgumentException ("Missing input schema for transformation using DataFrame" );
123123 }
124124
@@ -172,7 +172,7 @@ public void onEvent(SparkListenerEvent event) {
172172 interpreter .compile (generateSourceClass (className ));
173173 method = getTransformMethod (interpreter .getClassLoader (), className );
174174
175- isDataFrame = method .getParameterTypes ()[0 ]. equals ( DATAFRAME_TYPE );
175+ isDataFrame = isDataFrame ( method .getParameterTypes ()[0 ]);
176176 takeContext = method .getParameterTypes ().length == 2 ;
177177
178178 // Input schema shouldn't be null
@@ -302,9 +302,10 @@ private Method getTransformMethod(ClassLoader classLoader, String className) {
302302 }
303303
304304 Type [] parameterTypes = method .getGenericParameterTypes ();
305+ boolean isDataFrame = isDataFrame (parameterTypes [0 ]);
305306
306307 // The first parameter should be of type RDD[StructuredRecord] if it takes RDD
307- if (!parameterTypes [ 0 ]. equals ( DATAFRAME_TYPE ) ) {
308+ if (!isDataFrame ) {
308309 validateRDDType (parameterTypes [0 ],
309310 "The first parameter of the 'transform' method should have type as 'RDD[StructuredRecord]'" );
310311 }
@@ -317,8 +318,8 @@ private Method getTransformMethod(ClassLoader classLoader, String className) {
317318
318319 // The return type of the method must be RDD[StructuredRecord] if it takes RDD
319320 // Or it must be DataFrame if it takes DataFrame
320- if (parameterTypes [ 0 ]. equals ( DATAFRAME_TYPE ) ) {
321- if (!method .getReturnType (). equals ( DATAFRAME_TYPE )) {
321+ if (isDataFrame ) {
322+ if (!isDataFrame ( method .getGenericReturnType () )) {
322323 throw new IllegalArgumentException ("The return type of the 'transform' method should be 'DataFrame'" );
323324 }
324325 } else {
@@ -356,6 +357,22 @@ private String generateClassName(String stageName) {
356357 return nameBuilder .toString ();
357358 }
358359
360+ /**
361+ * Returns whether the given {@link Type} is a DataFrame type.
362+ */
363+ private boolean isDataFrame (Type type ) {
364+ if (DATAFRAME_TYPE == null ) {
365+ return false ;
366+ }
367+ if (type instanceof Class ) {
368+ return ((Class <?>) type ).isAssignableFrom (DATAFRAME_TYPE );
369+ }
370+ if (type instanceof ParameterizedType ) {
371+ return isDataFrame (((ParameterizedType ) type ).getRawType ());
372+ }
373+ return false ;
374+ }
375+
359376 /**
360377 * Configuration object for the plugin
361378 */
0 commit comments