
Commit 846abc4

Merge branch 'release/2.0' into develop

2 parents: 37d9b8f + cf1ae66

6 files changed (+123 additions, -54 deletions)


pom.xml

Lines changed: 2 additions & 2 deletions

@@ -28,8 +28,8 @@
     <!-- properties for script build step that creates the config files for the artifacts -->
     <widgets.dir>widgets</widgets.dir>
     <docs.dir>docs</docs.dir>
-    <data.pipeline.parent>system:cdap-data-pipeline[4.3.0-SNAPSHOT,5.0.0-SNAPSHOT)</data.pipeline.parent>
-    <data.stream.parent>system:cdap-data-streams[4.3.0-SNAPSHOT,5.0.0-SNAPSHOT)</data.stream.parent>
+    <data.pipeline.parent>system:cdap-data-pipeline[4.3.0-SNAPSHOT,6.0.0-SNAPSHOT)</data.pipeline.parent>
+    <data.stream.parent>system:cdap-data-streams[4.3.0-SNAPSHOT,6.0.0-SNAPSHOT)</data.stream.parent>
     <!-- this is here because project.basedir evaluates to null in the script build step -->
     <main.basedir>${project.basedir}</main.basedir>

src/main/java/co/cask/hydrator/plugin/spark/dynamic/ScalaSparkCompute.java

Lines changed: 79 additions & 15 deletions

@@ -34,17 +34,21 @@
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.function.Function;
 import org.apache.spark.rdd.RDD;
-import org.apache.spark.sql.DataFrame;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.types.DataType;
 import org.apache.spark.sql.types.StructType;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;

+import java.io.File;
 import java.io.IOException;
 import java.io.PrintWriter;
 import java.io.StringWriter;
 import java.lang.reflect.Method;
 import java.lang.reflect.ParameterizedType;
 import java.lang.reflect.Type;
+import java.nio.file.Files;
 import javax.annotation.Nullable;

 /**
@@ -55,12 +59,15 @@
 @Description("Executes user-provided Spark code written in Scala that performs RDD to RDD transformation")
 public class ScalaSparkCompute extends SparkCompute<StructuredRecord, StructuredRecord> {

+  private static final Logger LOG = LoggerFactory.getLogger(ScalaSparkCompute.class);
+
   private static final String CLASS_NAME_PREFIX = "co.cask.hydrator.plugin.spark.dynamic.generated.UserSparkCompute$";
+  private static final Class<?> DATAFRAME_TYPE = getDataFrameType();
   private static final Class<?>[][] ACCEPTABLE_PARAMETER_TYPES = new Class<?>[][] {
     { RDD.class, SparkExecutionPluginContext.class },
     { RDD.class },
-    { DataFrame.class, SparkExecutionPluginContext.class},
-    { DataFrame.class }
+    { DATAFRAME_TYPE, SparkExecutionPluginContext.class},
+    { DATAFRAME_TYPE }
   };

   private final ThreadLocal<SQLContext> sqlContextThreadLocal = new InheritableThreadLocal<>();
@@ -90,10 +97,16 @@ public void configurePipeline(PipelineConfigurer pipelineConfigurer) throws Ille
       throw new IllegalArgumentException("Unable to parse output schema " + config.getSchema(), e);
     }

-    if (!config.containsMacro("scalaCode") && Boolean.TRUE.equals(config.getDeployCompile())) {
+    if (!config.containsMacro("scalaCode") && !config.containsMacro("dependencies")
+        && Boolean.TRUE.equals(config.getDeployCompile())) {
       SparkInterpreter interpreter = SparkCompilers.createInterpreter();
       if (interpreter != null) {
+        File dir = null;
         try {
+          if (config.getDependencies() != null) {
+            dir = Files.createTempDirectory("sparkprogram").toFile();
+            SparkCompilers.addDependencies(dir, interpreter, config.getDependencies());
+          }
           // We don't need the actual stage name as this only happen in deployment time for compilation check.
           String className = generateClassName("dummy");
           interpreter.compile(generateSourceClass(className));
@@ -102,12 +115,16 @@ public void configurePipeline(PipelineConfigurer pipelineConfigurer) throws Ille
           Method method = getTransformMethod(interpreter.getClassLoader(), className);

           // If the method takes DataFrame, make sure it has input schema
-          if (method.getParameterTypes()[0].equals(DataFrame.class) && stageConfigurer.getInputSchema() == null) {
+          if (method.getParameterTypes()[0].equals(DATAFRAME_TYPE) && stageConfigurer.getInputSchema() == null) {
            throw new IllegalArgumentException("Missing input schema for transformation using DataFrame");
          }

        } catch (CompilationFailureException e) {
          throw new IllegalArgumentException(e.getMessage(), e);
+        } catch (IOException e) {
+          throw new RuntimeException(e);
+        } finally {
+          SparkCompilers.deleteDir(dir);
        }
      }
    }
@@ -117,9 +134,17 @@ public void configurePipeline(PipelineConfigurer pipelineConfigurer) throws Ille
   public void initialize(SparkExecutionPluginContext context) throws Exception {
     String className = generateClassName(context.getStageName());
     interpreter = context.createSparkInterpreter();
+    File dir = config.getDependencies() == null ? null : Files.createTempDirectory("sparkprogram").toFile();
+    try {
+      if (config.getDependencies() != null) {
+        SparkCompilers.addDependencies(dir, interpreter, config.getDependencies());
+      }
     interpreter.compile(generateSourceClass(className));
     method = getTransformMethod(interpreter.getClassLoader(), className);
-    isDataFrame = method.getParameterTypes()[0].equals(DataFrame.class);
+    } finally {
+      SparkCompilers.deleteDir(dir);
+    }
+    isDataFrame = method.getParameterTypes()[0].equals(DATAFRAME_TYPE);
     takeContext = method.getParameterTypes().length == 2;

     // Input schema shouldn't be null
@@ -154,18 +179,18 @@ public JavaRDD<StructuredRecord> transform(SparkExecutionPluginContext context,
     StructType rowType = DataFrames.toDataType(inputSchema);
     JavaRDD<Row> rowRDD = javaRDD.map(new RecordToRow(rowType));

-    DataFrame dataFrame = sqlContext.createDataFrame(rowRDD, rowType);
-    DataFrame result = (DataFrame) (takeContext ?
-      method.invoke(null, dataFrame, context) : method.invoke(null, dataFrame));
+    Object dataFrame = sqlContext.createDataFrame(rowRDD, rowType);
+    Object result = takeContext ? method.invoke(null, dataFrame, context) : method.invoke(null, dataFrame);

     // Convert the DataFrame back to RDD<StructureRecord>
     Schema outputSchema = context.getOutputSchema();
     if (outputSchema == null) {
       // If there is no output schema configured, derive it from the DataFrame
       // Otherwise, assume the DataFrame has the correct schema already
-      outputSchema = DataFrames.toSchema(result.schema());
+      outputSchema = DataFrames.toSchema((DataType) invokeDataFrameMethod(result, "schema"));
     }
-    return result.toJavaRDD().map(new RowToRecord(outputSchema));
+    //noinspection unchecked
+    return ((JavaRDD<Row>) invokeDataFrameMethod(result, "toJavaRDD")).map(new RowToRecord(outputSchema));
   }

   private String generateSourceClass(String className) {
@@ -251,7 +276,7 @@ private Method getTransformMethod(ClassLoader classLoader, String className) {
     Type[] parameterTypes = method.getGenericParameterTypes();

     // The first parameter should be of type RDD[StructuredRecord] if it takes RDD
-    if (!parameterTypes[0].equals(DataFrame.class)) {
+    if (!parameterTypes[0].equals(DATAFRAME_TYPE)) {
       validateRDDType(parameterTypes[0],
                       "The first parameter of the 'transform' method should have type as 'RDD[StructuredRecord]'");
     }
@@ -264,8 +289,8 @@ private Method getTransformMethod(ClassLoader classLoader, String className) {

     // The return type of the method must be RDD[StructuredRecord] if it takes RDD
     // Or it must be DataFrame if it takes DataFrame
-    if (parameterTypes[0].equals(DataFrame.class)) {
-      if (!method.getReturnType().equals(DataFrame.class)) {
+    if (parameterTypes[0].equals(DATAFRAME_TYPE)) {
+      if (!method.getReturnType().equals(DATAFRAME_TYPE)) {
         throw new IllegalArgumentException("The return type of the 'transform' method should be 'DataFrame'");
       }
     } else {
@@ -323,6 +348,16 @@ public static final class Config extends PluginConfig {
     @Macro
     private final String scalaCode;

+    @Description(
+      "Extra dependencies for the Spark program. " +
+      "It is a ',' separated list of URI for the location of dependency jars. " +
+      "A path can be ended with an asterisk '*' as a wildcard, in which all files with extension '.jar' under the " +
+      "parent path will be included."
+    )
+    @Macro
+    @Nullable
+    private final String dependencies;
+
     @Description("The schema of output objects. If no schema is given, it is assumed that the output schema is " +
       "the same as the input schema.")
     @Nullable
@@ -334,9 +369,11 @@ public static final class Config extends PluginConfig {
     @Nullable
     private final Boolean deployCompile;

-    public Config(String scalaCode, @Nullable String schema, @Nullable Boolean deployCompile) {
+    public Config(String scalaCode, @Nullable String schema, @Nullable String dependencies,
+                  @Nullable Boolean deployCompile) {
       this.scalaCode = scalaCode;
       this.schema = schema;
+      this.dependencies = dependencies;
       this.deployCompile = deployCompile;
     }

@@ -349,6 +386,11 @@ public String getSchema() {
       return schema;
     }

+    @Nullable
+    public String getDependencies() {
+      return dependencies;
+    }
+
     @Nullable
     public Boolean getDeployCompile() {
       return deployCompile;
@@ -388,4 +430,26 @@ public StructuredRecord call(Row row) throws Exception {
       return DataFrames.fromRow(row, schema);
     }
   }
+
+  @Nullable
+  private static Class<?> getDataFrameType() {
+    // For Spark1, it has the DataFrame class
+    // For Spark2, there is no more DataFrame class, and it becomes Dataset<Row>
+    try {
+      return ScalaSparkCompute.class.getClassLoader().loadClass("org.apache.spark.sql.DataFrame");
+    } catch (ClassNotFoundException e) {
+      try {
+        return ScalaSparkCompute.class.getClassLoader().loadClass("org.apache.spark.sql.Dataset");
+      } catch (ClassNotFoundException e1) {
+        LOG.warn("Failed to determine the type of Spark DataFrame. " +
+                 "DataFrame is not supported in the ScalaSparkCompute plugin.");
+        return null;
+      }
+    }
+  }
+
+  private static <T> T invokeDataFrameMethod(Object dataFrame, String methodName) throws Exception {
+    //noinspection unchecked
+    return (T) dataFrame.getClass().getMethod(methodName).invoke(dataFrame);
+  }
 }
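
The substance of this file's change is that every compile-time reference to org.apache.spark.sql.DataFrame is replaced by the reflectively resolved DATAFRAME_TYPE, and DataFrame values are held as plain Object with their methods invoked by name. That lets one plugin jar run on both Spark 1 (which ships a DataFrame class) and Spark 2 (where DataFrame became an alias for Dataset<Row>). A minimal standalone sketch of the same pattern, with the hypothetical class name DataFrameCompat (not part of the plugin), looks like this:

// Sketch only: resolve the DataFrame type reflectively and call its methods by
// name, so the code never links against a Spark-version-specific class.
import java.lang.reflect.Method;

public final class DataFrameCompat {

  // Spark 1 ships org.apache.spark.sql.DataFrame; Spark 2 removed it in favor
  // of Dataset<Row>, so fall back to org.apache.spark.sql.Dataset.
  public static Class<?> resolveDataFrameType() {
    try {
      return Class.forName("org.apache.spark.sql.DataFrame");
    } catch (ClassNotFoundException e) {
      try {
        return Class.forName("org.apache.spark.sql.Dataset");
      } catch (ClassNotFoundException e2) {
        return null; // neither Spark 1 nor Spark 2 is on the classpath
      }
    }
  }

  // Invoke a no-argument method such as "schema" or "toJavaRDD" on whatever
  // concrete DataFrame/Dataset instance was passed in.
  @SuppressWarnings("unchecked")
  public static <T> T invoke(Object dataFrame, String methodName) throws Exception {
    Method method = dataFrame.getClass().getMethod(methodName);
    return (T) method.invoke(dataFrame);
  }

  private DataFrameCompat() {
  }
}

In transform() above, the same idea appears as invokeDataFrameMethod(result, "schema") and invokeDataFrameMethod(result, "toJavaRDD"): the value returned by the user's transform method stays an Object, and only runtime reflection touches the version-specific API.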

src/main/java/co/cask/hydrator/plugin/spark/dynamic/ScalaSparkProgram.java

Lines changed: 2 additions & 32 deletions

@@ -35,11 +35,7 @@
 import java.io.IOException;
 import java.lang.reflect.Method;
 import java.lang.reflect.Modifier;
-import java.nio.file.FileVisitResult;
 import java.nio.file.Files;
-import java.nio.file.Path;
-import java.nio.file.SimpleFileVisitor;
-import java.nio.file.attribute.BasicFileAttributes;
 import java.util.concurrent.Callable;
 import javax.annotation.Nullable;

@@ -79,7 +75,7 @@ public ScalaSparkProgram(Config config) throws CompilationFailureException, IOEx
         getMethodCallable(interpreter.getClassLoader(), config.getMainClass(), null);
       }
     } finally {
-      deleteDir(dir);
+      SparkCompilers.deleteDir(dir);
     }
   } finally {
     interpreter.close();
@@ -98,7 +94,7 @@ public void run(JavaSparkExecutionContext sec) throws Exception {
       interpreter.compile(config.getScalaCode());
       getMethodCallable(interpreter.getClassLoader(), config.getMainClass(), sec).call();
     } finally {
-      deleteDir(dir);
+      SparkCompilers.deleteDir(dir);
     }
   }

@@ -166,32 +162,6 @@ public Void call() throws Exception {
     }
   }

-  /**
-   * Recursively delete a directory.
-   */
-  public static void deleteDir(@Nullable File dir) {
-    if (dir == null) {
-      return;
-    }
-    try {
-      Files.walkFileTree(dir.toPath(), new SimpleFileVisitor<Path>() {
-        @Override
-        public FileVisitResult visitFile(Path file, BasicFileAttributes attrs) throws IOException {
-          Files.deleteIfExists(file);
-          return FileVisitResult.CONTINUE;
-        }
-
-        @Override
-        public FileVisitResult postVisitDirectory(Path dir, IOException exc) throws IOException {
-          Files.deleteIfExists(dir);
-          return FileVisitResult.CONTINUE;
-        }
-      });
-    } catch (IOException e) {
-      LOG.warn("Failed to cleanup temporary directory {}", dir, e);
-    }
-  }
-
  /**
   * Plugin configuration
   */

src/main/java/co/cask/hydrator/plugin/spark/dynamic/SparkCompilers.java

Lines changed: 33 additions & 0 deletions

@@ -24,6 +24,8 @@
 import org.apache.hadoop.fs.LocatedFileStatus;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.fs.RemoteIterator;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 import scala.Function0;
 import scala.Option$;
 import scala.collection.JavaConversions;
@@ -44,7 +46,10 @@
 import java.net.URI;
 import java.net.URISyntaxException;
 import java.net.URL;
+import java.nio.file.FileVisitResult;
 import java.nio.file.Files;
+import java.nio.file.SimpleFileVisitor;
+import java.nio.file.attribute.BasicFileAttributes;
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Collections;
@@ -57,6 +62,8 @@
  */
 public final class SparkCompilers {

+  private static final Logger LOG = LoggerFactory.getLogger(SparkCompilers.class);
+
   private static final FilenameFilter JAR_FILE_FILTER = new FilenameFilter() {
     @Override
     public boolean accept(File dir, String name) {
@@ -214,4 +221,30 @@ private static void copyPathAndAdd(FileSystem fs, Path from, File dir, Collectio
   private SparkCompilers() {
     // no-op
   }
+
+  /**
+   * Recursively delete a directory.
+   */
+  public static void deleteDir(@Nullable File dir) {
+    if (dir == null) {
+      return;
+    }
+    try {
+      Files.walkFileTree(dir.toPath(), new SimpleFileVisitor<java.nio.file.Path>() {
+        @Override
+        public FileVisitResult visitFile(java.nio.file.Path file, BasicFileAttributes attrs) throws IOException {
+          Files.deleteIfExists(file);
+          return FileVisitResult.CONTINUE;
+        }
+
+        @Override
+        public FileVisitResult postVisitDirectory(java.nio.file.Path dir, IOException exc) throws IOException {
+          Files.deleteIfExists(dir);
+          return FileVisitResult.CONTINUE;
+        }
+      });
+    } catch (IOException e) {
+      LOG.warn("Failed to cleanup temporary directory {}", dir, e);
+    }
+  }
 }
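
The recursive deleteDir helper moves here from ScalaSparkProgram so that both ScalaSparkProgram and the new dependency handling in ScalaSparkCompute share one cleanup path: create a temporary directory only when dependencies are configured, and always delete it in a finally block. The new dependencies property is documented as a ','-separated list of jar URIs in which a path ending in '*' pulls in every '.jar' file under the parent path. SparkCompilers.addDependencies is not part of this diff, so the following is only a rough sketch of how such a value could be expanded for local paths; the class and method names are hypothetical, and the real implementation also deals with Hadoop FileSystem URIs (note copyPathAndAdd in the hunk header above).

// Hypothetical sketch: expand a dependencies value such as
// "/opt/libs/custom-udfs.jar,/opt/spark/extra/*" into a list of local jar files.
import java.io.File;
import java.io.FilenameFilter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public final class DependencyListSketch {

  // Accept only *.jar files, analogous to the JAR_FILE_FILTER field in SparkCompilers.
  private static final FilenameFilter JAR_FILE_FILTER = new FilenameFilter() {
    @Override
    public boolean accept(File dir, String name) {
      return name.endsWith(".jar");
    }
  };

  public static List<File> expand(String dependencies) {
    List<File> jars = new ArrayList<>();
    for (String entry : dependencies.split(",")) {
      String path = entry.trim();
      if (path.endsWith("*")) {
        // Wildcard: include every .jar file under the parent path.
        File parent = new File(path.substring(0, path.length() - 1));
        File[] matches = parent.listFiles(JAR_FILE_FILTER);
        if (matches != null) {
          jars.addAll(Arrays.asList(matches));
        }
      } else {
        jars.add(new File(path));
      }
    }
    return jars;
  }

  private DependencyListSketch() {
  }
}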

widgets/ScalaSparkCompute-sparkcompute.json

Lines changed: 5 additions & 0 deletions

@@ -15,6 +15,11 @@
         "default": "/**\n * Transforms the provided input Apache Spark RDD or DataFrame into another RDD or DataFrame.\n *\n * The input DataFrame has the same schema as the input schema to this stage and the transform method should return a DataFrame that has the same schema as the output schema setup for this stage.\n * To emit logs, use: \n * import org.slf4j.LoggerFactory\n * val logger = LoggerFactory.getLogger('mylogger')\n * logger.info('Logging')\n *\n *\n * @param input the input DataFrame which has the same schema as the input schema to this stage.\n * @param context a SparkExecutionPluginContext object that can be used to emit zero or more records (using the emitter.emit() method) or errors (using the emitter.emitError() method) \n * @param context an object that provides access to:\n * 1. CDAP Datasets and Streams - context.fromDataset('counts'); or context.fromStream('input');\n * 2. Original Spark Context - context.getSparkContext();\n * 3. Runtime Arguments - context.getArguments.get('priceThreshold')\n */\ndef transform(df: DataFrame, context: SparkExecutionPluginContext) : DataFrame = {\n df\n}"
       }
     },
+    {
+      "widget-type": "csv",
+      "label": "Dependencies",
+      "name": "dependencies"
+    },
     {
       "widget-type": "select",
       "label": "Compile at Deployment Time",

widgets/ScalaSparkProgram-sparkprogram.json

Lines changed: 2 additions & 5 deletions

@@ -23,12 +23,9 @@
       }
     },
     {
-      "widget-type": "dsv",
+      "widget-type": "csv",
       "label": "Dependencies",
-      "name": "dependencies",
-      "widget-attributes": {
-        "delimiter": ","
-      }
+      "name": "dependencies"
     },
     {
       "widget-type": "select",
