
Commit 5dadda7

Merge remote-tracking branch 'origin/release/2.2' into merge-2-2

2 parents: 848217c + 72c5aa0

3 files changed: +94, -32 lines

src/main/java/io/cdap/plugin/spark/dynamic/ScalaSparkCodeExecutor.java

Lines changed: 6 additions & 2 deletions
@@ -165,6 +165,10 @@ public void onEvent(SparkListenerEvent event) {
     }
   }
 
+  public boolean isDataFrame() {
+    return isDataFrame;
+  }
+
   /**
    * Execute interpreted code on the given RDD.
    */
@@ -177,10 +181,10 @@ public Object execute(SparkExecutionPluginContext context,
     if (!isDataFrame) {
       if (takeContext) {
         //noinspection unchecked
-        return ((RDD<StructuredRecord>) method.invoke(null, javaRDD.rdd(), context)).toJavaRDD();
+        return method.invoke(null, javaRDD.rdd(), context);
       } else {
         //noinspection unchecked
-        return ((RDD<StructuredRecord>) method.invoke(null, javaRDD.rdd())).toJavaRDD();
+        return method.invoke(null, javaRDD.rdd());
       }
     }
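With this change, execute() hands back the raw result of the reflective call instead of casting it to RDD<StructuredRecord> and converting it, and the new isDataFrame() accessor lets callers choose the conversion themselves. As a readability aid, the two RDD-based transform shapes matching the takeContext and no-context branches above look roughly like this in the user-supplied Scala (a sketch with placeholder identity bodies, written in the bare-snippet form the tests pass as scalaCode):

// Sketch only: signatures mirror the two invoke(...) branches above;
// the identity bodies are placeholders, not code from this commit.
def transform(rdd: RDD[StructuredRecord], context: SparkExecutionPluginContext): RDD[StructuredRecord] = {
  rdd  // returned as-is; ScalaSparkCompute now converts it with toJavaRDD()
}

def transform(rdd: RDD[StructuredRecord]): RDD[StructuredRecord] = {
  rdd
}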

src/main/java/io/cdap/plugin/spark/dynamic/ScalaSparkCompute.java

Lines changed: 9 additions & 1 deletion
@@ -30,6 +30,7 @@
 import io.cdap.cdap.etl.api.batch.SparkExecutionPluginContext;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.function.Function;
+import org.apache.spark.rdd.RDD;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.types.DataType;
 
@@ -48,6 +49,7 @@ public class ScalaSparkCompute extends SparkCompute<StructuredRecord, Structured
   // A strong reference is needed to keep the compiled classes around
   @SuppressWarnings("FieldCanBeLocal")
   private transient ScalaSparkCodeExecutor codeExecutor;
+  private transient boolean isRDD;
 
   public ScalaSparkCompute(Config config) {
     this.config = config;
@@ -77,13 +79,19 @@ public void configurePipeline(PipelineConfigurer pipelineConfigurer) throws Ille
   public void initialize(SparkExecutionPluginContext context) throws Exception {
     codeExecutor = new ScalaSparkCodeExecutor(config.getScalaCode(), config.getDependencies(), "transform", false);
     codeExecutor.initialize(context);
+    isRDD = !codeExecutor.isDataFrame();
   }
 
   @Override
   public JavaRDD<StructuredRecord> transform(SparkExecutionPluginContext context,
                                              JavaRDD<StructuredRecord> javaRDD) throws Exception {
     Object result = codeExecutor.execute(context, javaRDD);
 
+    if (isRDD) {
+      //noinspection unchecked
+      return ((RDD<StructuredRecord>) result).toJavaRDD();
+    }
+
     // Convert the DataFrame back to RDD<StructureRecord>
     Schema outputSchema = context.getOutputSchema();
     if (outputSchema == null) {
@@ -176,7 +184,7 @@ public RowToRecord(Schema schema) {
     }
 
     @Override
-    public StructuredRecord call(Row row) throws Exception {
+    public StructuredRecord call(Row row) {
       return DataFrames.fromRow(row, schema);
     }
   }
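When the compiled code is DataFrame-based (isRDD is false), transform() still falls through to the existing path that converts the returned DataFrame back to an RDD of StructuredRecord via RowToRecord. Unescaped from the println strings in the test below, the DataFrame-based word-count transform reads as follows (shown only for readability; it is the same code the test passes as scalaCode):

def transform(df: DataFrame) : DataFrame = {
  // split the "body" column into one "word" per row, then count per word with Spark SQL
  val splitted = df.explode("body", "word") {
    line: String => line.split("\\s+")
  }
  splitted.registerTempTable("splitted")
  splitted.sqlContext.sql("SELECT word, count(*) as count FROM splitted GROUP BY word")
}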

src/test/java/io/cdap/plugin/spark/dynamic/ScalaSparkTest.java

Lines changed: 79 additions & 29 deletions
@@ -72,6 +72,7 @@
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.UUID;
 import java.util.concurrent.TimeUnit;
 
 /**
@@ -160,8 +161,7 @@ public void testScalaProgram() throws Exception {
     Map<String, String> runtimeArgs = new HashMap<>(RuntimeArguments.addScope(Scope.DATASET, "text", inputArgs));
 
     WorkflowManager workflowManager = appManager.getWorkflowManager(SmartWorkflow.NAME);
-    workflowManager.start(runtimeArgs);
-    workflowManager.waitForRun(ProgramRunStatus.COMPLETED, 5, TimeUnit.MINUTES);
+    workflowManager.startAndWaitForRun(runtimeArgs, ProgramRunStatus.COMPLETED, 5, TimeUnit.MINUTES);
 
     // Validate the result
     KeyValueTable kvTable = this.<KeyValueTable>getDataset("kvTable").get();
@@ -269,7 +269,40 @@ public void testScalaSparkProgramClosure() throws Exception {
   }
 
   @Test
-  public void testScalaSparkCompute() throws Exception {
+  public void testScalaSparkComputeDataFrame() throws Exception {
+    StringWriter codeWriter = new StringWriter();
+    try (PrintWriter printer = new PrintWriter(codeWriter, true)) {
+      printer.println("def transform(df: DataFrame) : DataFrame = {");
+      printer.println("  val splitted = df.explode(\"body\", \"word\") { ");
+      printer.println("    line: String => line.split(\"\\\\s+\")");
+      printer.println("  }");
+      printer.println("  splitted.registerTempTable(\"splitted\")");
+      printer.println("  splitted.sqlContext.sql(\"SELECT word, count(*) as count FROM splitted GROUP BY word\")");
+      printer.println("}");
+    }
+
+    testWordCountCompute(codeWriter.toString());
+  }
+
+  @Test
+  public void testScalaSparkComputeRDD() throws Exception {
+    StringWriter codeWriter = new StringWriter();
+    try (PrintWriter printer = new PrintWriter(codeWriter, true)) {
+      printer.println(
+        "def transform(rdd: RDD[StructuredRecord], context:SparkExecutionPluginContext) : RDD[StructuredRecord] = {");
+      printer.println("  val schema = context.getOutputSchema");
+      printer.println("  rdd");
+      printer.println("    .flatMap(_.get[String](\"body\").split(\"\\\\s+\"))");
+      printer.println("    .map(s => (s, 1L))");
+      printer.println("    .reduceByKey(_ + _)");
+      printer.println("    .map(t => StructuredRecord.builder(schema).set(\"word\", t._1).set(\"count\", t._2).build)");
+      printer.println("}");
+    }
+
+    testWordCountCompute(codeWriter.toString());
+  }
+
+  private void testWordCountCompute(String code) throws Exception {
     Schema inputSchema = Schema.recordOf(
       "input",
       Schema.Field.of("body", Schema.nullableOf(Schema.of(Schema.Type.STRING)))
@@ -281,25 +314,17 @@ public void testScalaSparkCompute() throws Exception {
       Schema.Field.of("count", Schema.nullableOf(Schema.of(Schema.Type.LONG)))
     );
 
-    StringWriter codeWriter = new StringWriter();
-    try (PrintWriter printer = new PrintWriter(codeWriter, true)) {
-      printer.println("def transform(df: DataFrame) : DataFrame = {");
-      printer.println("  val splitted = df.explode(\"body\", \"word\") { ");
-      printer.println("    line: String => line.split(\"\\\\s+\")");
-      printer.println("  }");
-      printer.println("  splitted.registerTempTable(\"splitted\")");
-      printer.println("  splitted.sqlContext.sql(\"SELECT word, count(*) as count FROM splitted GROUP BY word\")");
-      printer.println("}");
-    }
+    String inputTable = UUID.randomUUID().toString();
+    String outputTable = UUID.randomUUID().toString();
 
     // Pipeline configuration
     ETLBatchConfig etlConfig = ETLBatchConfig.builder("* * * * *")
-      .addStage(new ETLStage("source", MockSource.getPlugin("singleInput", inputSchema)))
+      .addStage(new ETLStage("source", MockSource.getPlugin(inputTable, inputSchema)))
       .addStage(new ETLStage("compute", new ETLPlugin("ScalaSparkCompute", SparkCompute.PLUGIN_TYPE, ImmutableMap.of(
-        "scalaCode", codeWriter.toString(),
+        "scalaCode", code,
         "schema", computeSchema.toString()
       ))))
-      .addStage(new ETLStage("sink", MockSink.getPlugin("singleOutput")))
+      .addStage(new ETLStage("sink", MockSink.getPlugin(outputTable)))
       .addConnection("source", "compute")
       .addConnection("compute", "sink")
       .build();
@@ -308,11 +333,11 @@ public void testScalaSparkCompute() throws Exception {
     ArtifactSummary artifactSummary = new ArtifactSummary(DATAPIPELINE_ARTIFACT_ID.getArtifact(),
                                                           DATAPIPELINE_ARTIFACT_ID.getVersion());
     AppRequest<ETLBatchConfig> appRequest = new AppRequest<>(artifactSummary, etlConfig);
-    ApplicationId appId = NamespaceId.DEFAULT.app("ScalaSparkComputeApp");
+    ApplicationId appId = NamespaceId.DEFAULT.app(UUID.randomUUID().toString());
     ApplicationManager appManager = deployApplication(appId, appRequest);
 
     // write records to source
-    DataSetManager<Table> inputManager = getDataset(NamespaceId.DEFAULT.dataset("singleInput"));
+    DataSetManager<Table> inputManager = getDataset(NamespaceId.DEFAULT.dataset(inputTable));
     List<StructuredRecord> inputRecords = new ArrayList<>();
     for (int i = 0; i < 10; i++) {
       inputRecords.add(StructuredRecord.builder(inputSchema).set("body", "Line " + i).build());
@@ -326,7 +351,7 @@ public void testScalaSparkCompute() throws Exception {
 
     // Verify result written to sink.
     // It has two fields, word and count.
-    DataSetManager<Table> sinkManager = getDataset("singleOutput");
+    DataSetManager<Table> sinkManager = getDataset(outputTable);
     Map<String, StructuredRecord> wordCounts =
       Maps.uniqueIndex(Sets.newHashSet(MockSink.readOutput(sinkManager)), new Function<StructuredRecord, String>() {
         @Override
@@ -343,13 +368,28 @@ public String apply(StructuredRecord record) {
   }
 
   @Test
-  public void testScalaSparkSink() throws Exception {
-    Schema inputSchema = Schema.recordOf(
-      "input",
-      Schema.Field.of("body", Schema.nullableOf(Schema.of(Schema.Type.STRING)))
-    );
+  public void testScalaSparkSinkRDD() throws Exception {
+    File testFolder = TEMP_FOLDER.newFolder("scalaSinkRDDOutput");
+    File outputFolder = new File(testFolder, "output");
+    StringWriter codeWriter = new StringWriter();
+    try (PrintWriter printer = new PrintWriter(codeWriter, true)) {
+      printer.println(
+        "def sink(rdd: RDD[StructuredRecord], context:SparkExecutionPluginContext) : Unit = {");
+      printer.println("  val schema = context.getOutputSchema");
+      printer.println("  rdd");
+      printer.println("    .flatMap(_.get[String](\"body\").split(\"\\\\s+\"))");
+      printer.println("    .map(s => (s, 1L))");
+      printer.println("    .reduceByKey(_ + _)");
+      printer.println("    .map(t => t._1 + \" \" + t._2)");
+      printer.println("    .saveAsTextFile(\"" + outputFolder.getAbsolutePath() + "\")");
+      printer.println("}");
+    }
+    testWordCountSink(codeWriter.toString(), outputFolder);
+  }
 
-    File testFolder = TEMP_FOLDER.newFolder("scalaSinkOutput");
+  @Test
+  public void testScalaSparkSinkDataFrame() throws Exception {
+    File testFolder = TEMP_FOLDER.newFolder("scalaSinkDataframeOutput");
     File outputFolder = new File(testFolder, "output");
     StringWriter codeWriter = new StringWriter();
     try (PrintWriter printer = new PrintWriter(codeWriter, true)) {
@@ -363,24 +403,34 @@ public void testScalaSparkSink() throws Exception {
       printer.println("  out.write.format(\"text\").save(\"" + outputFolder.getAbsolutePath() + "\")");
       printer.println("}");
     }
+    testWordCountSink(codeWriter.toString(), outputFolder);
+  }
+
+  private void testWordCountSink(String code, File outputFolder) throws Exception {
+    Schema inputSchema = Schema.recordOf(
+      "input",
+      Schema.Field.of("body", Schema.nullableOf(Schema.of(Schema.Type.STRING)))
+    );
+
+    String inputTable = UUID.randomUUID().toString();
 
     // Pipeline configuration
     ETLBatchConfig etlConfig = ETLBatchConfig.builder("* * * * *")
-      .addStage(new ETLStage("source", MockSource.getPlugin("sinkInput", inputSchema)))
+      .addStage(new ETLStage("source", MockSource.getPlugin(inputTable, inputSchema)))
       .addStage(new ETLStage("sink", new ETLPlugin("ScalaSparkSink", SparkSink.PLUGIN_TYPE,
-                                                   ImmutableMap.of("scalaCode", codeWriter.toString()))))
+                                                   ImmutableMap.of("scalaCode", code))))
       .addConnection("source", "sink")
       .build();
 
     // Deploy the pipeline
     ArtifactSummary artifactSummary = new ArtifactSummary(DATAPIPELINE_ARTIFACT_ID.getArtifact(),
                                                           DATAPIPELINE_ARTIFACT_ID.getVersion());
     AppRequest<ETLBatchConfig> appRequest = new AppRequest<>(artifactSummary, etlConfig);
-    ApplicationId appId = NamespaceId.DEFAULT.app("ScalaSparkSinkApp");
+    ApplicationId appId = NamespaceId.DEFAULT.app(UUID.randomUUID().toString());
     ApplicationManager appManager = deployApplication(appId, appRequest);
 
     // write records to source
-    DataSetManager<Table> inputManager = getDataset(NamespaceId.DEFAULT.dataset("sinkInput"));
+    DataSetManager<Table> inputManager = getDataset(NamespaceId.DEFAULT.dataset(inputTable));
     List<StructuredRecord> inputRecords = new ArrayList<>();
     for (int i = 0; i < 10; i++) {
       inputRecords.add(StructuredRecord.builder(inputSchema).set("body", "Line " + i).build());
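For readability, the RDD-based sink built by the new testScalaSparkSinkRDD, unescaped from the println strings above into plain Scala; the output path shown here is a placeholder for the value of outputFolder.getAbsolutePath() that the test splices in:

def sink(rdd: RDD[StructuredRecord], context: SparkExecutionPluginContext) : Unit = {
  val schema = context.getOutputSchema
  rdd
    .flatMap(_.get[String]("body").split("\\s+"))
    .map(s => (s, 1L))
    .reduceByKey(_ + _)
    .map(t => t._1 + " " + t._2)
    .saveAsTextFile("/tmp/scala-sink-output")  // placeholder; the test uses outputFolder.getAbsolutePath()
}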
