 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.function.Function;
 import org.apache.spark.rdd.RDD;
-import org.apache.spark.sql.DataFrame;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.types.DataType;
 import org.apache.spark.sql.types.StructType;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;

+import java.io.File;
 import java.io.IOException;
 import java.io.PrintWriter;
 import java.io.StringWriter;
 import java.lang.reflect.Method;
 import java.lang.reflect.ParameterizedType;
 import java.lang.reflect.Type;
+import java.nio.file.Files;
 import javax.annotation.Nullable;

 /**
 @Description("Executes user-provided Spark code written in Scala that performs RDD to RDD transformation")
 public class ScalaSparkCompute extends SparkCompute<StructuredRecord, StructuredRecord> {

+  private static final Logger LOG = LoggerFactory.getLogger(ScalaSparkCompute.class);
+
   private static final String CLASS_NAME_PREFIX = "co.cask.hydrator.plugin.spark.dynamic.generated.UserSparkCompute$";
+  private static final Class<?> DATAFRAME_TYPE = getDataFrameType();
   private static final Class<?>[][] ACCEPTABLE_PARAMETER_TYPES = new Class<?>[][] {
     { RDD.class, SparkExecutionPluginContext.class },
     { RDD.class },
-    { DataFrame.class, SparkExecutionPluginContext.class },
-    { DataFrame.class }
+    { DATAFRAME_TYPE, SparkExecutionPluginContext.class },
+    { DATAFRAME_TYPE }
   };

   private final ThreadLocal<SQLContext> sqlContextThreadLocal = new InheritableThreadLocal<>();
@@ -90,10 +97,16 @@ public void configurePipeline(PipelineConfigurer pipelineConfigurer) throws Ille
       throw new IllegalArgumentException("Unable to parse output schema " + config.getSchema(), e);
     }

-    if (!config.containsMacro("scalaCode") && Boolean.TRUE.equals(config.getDeployCompile())) {
+    if (!config.containsMacro("scalaCode") && !config.containsMacro("dependencies")
+      && Boolean.TRUE.equals(config.getDeployCompile())) {
       SparkInterpreter interpreter = SparkCompilers.createInterpreter();
       if (interpreter != null) {
+        File dir = null;
         try {
+          if (config.getDependencies() != null) {
+            dir = Files.createTempDirectory("sparkprogram").toFile();
+            SparkCompilers.addDependencies(dir, interpreter, config.getDependencies());
+          }
           // We don't need the actual stage name as this only happen in deployment time for compilation check.
           String className = generateClassName("dummy");
           interpreter.compile(generateSourceClass(className));
@@ -102,12 +115,16 @@ public void configurePipeline(PipelineConfigurer pipelineConfigurer) throws Ille
           Method method = getTransformMethod(interpreter.getClassLoader(), className);

           // If the method takes DataFrame, make sure it has input schema
-          if (method.getParameterTypes()[0].equals(DataFrame.class) && stageConfigurer.getInputSchema() == null) {
+          if (method.getParameterTypes()[0].equals(DATAFRAME_TYPE) && stageConfigurer.getInputSchema() == null) {
             throw new IllegalArgumentException("Missing input schema for transformation using DataFrame");
           }

         } catch (CompilationFailureException e) {
           throw new IllegalArgumentException(e.getMessage(), e);
+        } catch (IOException e) {
+          throw new RuntimeException(e);
+        } finally {
+          SparkCompilers.deleteDir(dir);
         }
       }
     }
@@ -117,9 +134,17 @@ public void configurePipeline(PipelineConfigurer pipelineConfigurer) throws Ille
   public void initialize(SparkExecutionPluginContext context) throws Exception {
     String className = generateClassName(context.getStageName());
     interpreter = context.createSparkInterpreter();
+    File dir = config.getDependencies() == null ? null : Files.createTempDirectory("sparkprogram").toFile();
+    try {
+      if (config.getDependencies() != null) {
+        SparkCompilers.addDependencies(dir, interpreter, config.getDependencies());
+      }
       interpreter.compile(generateSourceClass(className));
       method = getTransformMethod(interpreter.getClassLoader(), className);
-    isDataFrame = method.getParameterTypes()[0].equals(DataFrame.class);
+    } finally {
+      SparkCompilers.deleteDir(dir);
+    }
+    isDataFrame = method.getParameterTypes()[0].equals(DATAFRAME_TYPE);
     takeContext = method.getParameterTypes().length == 2;

     // Input schema shouldn't be null
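Both the deploy-time check above and initialize follow the same lifecycle for the new dependencies option: create a scratch directory only when dependencies are configured, hand the jars to the interpreter before compiling, and always remove the directory in a finally block. Below is a minimal, self-contained sketch of that pattern; resolveDependencies and deleteRecursively are hypothetical stand-ins for SparkCompilers.addDependencies and SparkCompilers.deleteDir, and the compile step is elided.

    import java.io.File;
    import java.io.IOException;
    import java.nio.file.Files;

    public final class DependencyScratchDirSketch {

      public static void compileWithDependencies(String scalaCode, String dependencies) throws IOException {
        // Scratch directory is only needed when extra dependency jars were configured.
        File dir = dependencies == null ? null : Files.createTempDirectory("sparkprogram").toFile();
        try {
          if (dir != null) {
            resolveDependencies(dir, dependencies);   // hypothetical stand-in for SparkCompilers.addDependencies
          }
          // ... compile 'scalaCode' with the interpreter here ...
        } finally {
          deleteRecursively(dir);                     // hypothetical stand-in for SparkCompilers.deleteDir
        }
      }

      // Hypothetical: fetch each ','-separated URI into 'dir' so the compiler classpath can pick the jars up.
      private static void resolveDependencies(File dir, String dependencies) throws IOException {
      }

      // Hypothetical: remove the scratch directory recursively; tolerates null, as the finally blocks above require.
      private static void deleteRecursively(File dir) {
        if (dir == null) {
          return;
        }
        File[] children = dir.listFiles();
        if (children != null) {
          for (File child : children) {
            deleteRecursively(child);
          }
        }
        dir.delete();
      }
    }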
@@ -154,18 +179,18 @@ public JavaRDD<StructuredRecord> transform(SparkExecutionPluginContext context,
     StructType rowType = DataFrames.toDataType(inputSchema);
     JavaRDD<Row> rowRDD = javaRDD.map(new RecordToRow(rowType));

-    DataFrame dataFrame = sqlContext.createDataFrame(rowRDD, rowType);
-    DataFrame result = (DataFrame) (takeContext ?
-      method.invoke(null, dataFrame, context) : method.invoke(null, dataFrame));
+    Object dataFrame = sqlContext.createDataFrame(rowRDD, rowType);
+    Object result = takeContext ? method.invoke(null, dataFrame, context) : method.invoke(null, dataFrame);

     // Convert the DataFrame back to RDD<StructureRecord>
     Schema outputSchema = context.getOutputSchema();
     if (outputSchema == null) {
       // If there is no output schema configured, derive it from the DataFrame
       // Otherwise, assume the DataFrame has the correct schema already
-      outputSchema = DataFrames.toSchema(result.schema());
+      outputSchema = DataFrames.toSchema((DataType) invokeDataFrameMethod(result, "schema"));
     }
-    return result.toJavaRDD().map(new RowToRecord(outputSchema));
+    //noinspection unchecked
+    return ((JavaRDD<Row>) invokeDataFrameMethod(result, "toJavaRDD")).map(new RowToRecord(outputSchema));
   }

   private String generateSourceClass(String className) {
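The transform hunk above is the heart of the Spark 1/Spark 2 compatibility change: the frame returned by the user code is held as a plain Object, and its schema()/toJavaRDD() methods are looked up by name at runtime, so the compiled plugin never references org.apache.spark.sql.DataFrame directly. A minimal sketch of that reflective pattern follows; it uses an ordinary String as the stand-in target so it runs without Spark on the classpath, and invokeNoArg mirrors the invokeDataFrameMethod helper added later in this diff.

    public final class ReflectiveCallSketch {

      // Mirrors invokeDataFrameMethod: resolve a no-argument public method by name on the
      // runtime class of the target and invoke it.
      @SuppressWarnings("unchecked")
      private static <T> T invokeNoArg(Object target, String methodName) throws Exception {
        return (T) target.getClass().getMethod(methodName).invoke(target);
      }

      public static void main(String[] args) throws Exception {
        Object target = "a DataFrame or Dataset would go here";   // stand-in; no Spark types needed at compile time
        Integer length = invokeNoArg(target, "length");           // analogous to invokeDataFrameMethod(result, "schema")
        System.out.println("length() resolved reflectively: " + length);
      }
    }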
@@ -251,7 +276,7 @@ private Method getTransformMethod(ClassLoader classLoader, String className) {
     Type[] parameterTypes = method.getGenericParameterTypes();

     // The first parameter should be of type RDD[StructuredRecord] if it takes RDD
-    if (!parameterTypes[0].equals(DataFrame.class)) {
+    if (!parameterTypes[0].equals(DATAFRAME_TYPE)) {
       validateRDDType(parameterTypes[0],
                       "The first parameter of the 'transform' method should have type as 'RDD[StructuredRecord]'");
     }
@@ -264,8 +289,8 @@ private Method getTransformMethod(ClassLoader classLoader, String className) {

     // The return type of the method must be RDD[StructuredRecord] if it takes RDD
     // Or it must be DataFrame if it takes DataFrame
-    if (parameterTypes[0].equals(DataFrame.class)) {
-      if (!method.getReturnType().equals(DataFrame.class)) {
+    if (parameterTypes[0].equals(DATAFRAME_TYPE)) {
+      if (!method.getReturnType().equals(DATAFRAME_TYPE)) {
         throw new IllegalArgumentException("The return type of the 'transform' method should be 'DataFrame'");
       }
     } else {
@@ -323,6 +348,16 @@ public static final class Config extends PluginConfig {
     @Macro
     private final String scalaCode;

+    @Description(
+      "Extra dependencies for the Spark program. " +
+      "It is a ',' separated list of URI for the location of dependency jars. " +
+      "A path can be ended with an asterisk '*' as a wildcard, in which all files with extension '.jar' under the " +
+      "parent path will be included."
+    )
+    @Macro
+    @Nullable
+    private final String dependencies;
+
     @Description("The schema of output objects. If no schema is given, it is assumed that the output schema is " +
       "the same as the input schema.")
     @Nullable
@@ -334,9 +369,11 @@ public static final class Config extends PluginConfig {
     @Nullable
     private final Boolean deployCompile;

-    public Config(String scalaCode, @Nullable String schema, @Nullable Boolean deployCompile) {
+    public Config(String scalaCode, @Nullable String schema, @Nullable String dependencies,
+                  @Nullable Boolean deployCompile) {
       this.scalaCode = scalaCode;
       this.schema = schema;
+      this.dependencies = dependencies;
       this.deployCompile = deployCompile;
     }

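For reference, the new constructor above takes the dependencies string in the format documented by the @Description added earlier in this diff: a ','-separated list of jar URIs, where a path ending in '*' pulls in every '.jar' under the parent directory. A hedged construction sketch, as a unit test might write it; only the constructor signature comes from the diff, while the package in the import, the paths, and the Scala snippet are illustrative assumptions.

    // Package name is an assumption inferred from the CLASS_NAME_PREFIX constant in this file;
    // adjust it to wherever ScalaSparkCompute actually lives.
    import co.cask.hydrator.plugin.spark.dynamic.ScalaSparkCompute;

    public final class ConfigSketch {
      public static ScalaSparkCompute.Config exampleConfig() {
        return new ScalaSparkCompute.Config(
          "def transform(df: DataFrame): DataFrame = df",          // scalaCode: user transform (illustrative)
          null,                                                    // schema: null means reuse the input schema
          "file:/opt/plugins/udfs.jar,file:/opt/plugins/libs/*",   // dependencies: ','-separated URIs, '*' = all jars in dir
          true                                                     // deployCompile: compile-check at deployment time
        );
      }
    }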
@@ -349,6 +386,11 @@ public String getSchema() {
       return schema;
     }

+    @Nullable
+    public String getDependencies() {
+      return dependencies;
+    }
+
     @Nullable
     public Boolean getDeployCompile() {
       return deployCompile;
@@ -388,4 +430,26 @@ public StructuredRecord call(Row row) throws Exception {
       return DataFrames.fromRow(row, schema);
     }
   }
+
+  @Nullable
+  private static Class<?> getDataFrameType() {
+    // For Spark1, it has the DataFrame class
+    // For Spark2, there is no more DataFrame class, and it becomes Dataset<Row>
+    try {
+      return ScalaSparkCompute.class.getClassLoader().loadClass("org.apache.spark.sql.DataFrame");
+    } catch (ClassNotFoundException e) {
+      try {
+        return ScalaSparkCompute.class.getClassLoader().loadClass("org.apache.spark.sql.Dataset");
+      } catch (ClassNotFoundException e1) {
+        LOG.warn("Failed to determine the type of Spark DataFrame. " +
+                   "DataFrame is not supported in the ScalaSparkCompute plugin.");
+        return null;
+      }
+    }
+  }
+
+  private static <T> T invokeDataFrameMethod(Object dataFrame, String methodName) throws Exception {
+    //noinspection unchecked
+    return (T) dataFrame.getClass().getMethod(methodName).invoke(dataFrame);
+  }
 }