 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.UUID;
 import java.util.concurrent.TimeUnit;
 
 /**
@@ -160,8 +161,7 @@ public void testScalaProgram() throws Exception {
     Map<String, String> runtimeArgs = new HashMap<>(RuntimeArguments.addScope(Scope.DATASET, "text", inputArgs));
 
     WorkflowManager workflowManager = appManager.getWorkflowManager(SmartWorkflow.NAME);
-    workflowManager.start(runtimeArgs);
-    workflowManager.waitForRun(ProgramRunStatus.COMPLETED, 5, TimeUnit.MINUTES);
+    workflowManager.startAndWaitForRun(runtimeArgs, ProgramRunStatus.COMPLETED, 5, TimeUnit.MINUTES);
 
     // Validate the result
     KeyValueTable kvTable = this.<KeyValueTable>getDataset("kvTable").get();
@@ -269,7 +269,40 @@ public void testScalaSparkProgramClosure() throws Exception {
   }
 
   @Test
-  public void testScalaSparkCompute() throws Exception {
+  public void testScalaSparkComputeDataFrame() throws Exception {
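+    // Word count using the DataFrame API: explode each "body" line into words, then aggregate with Spark SQL.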
+    StringWriter codeWriter = new StringWriter();
+    try (PrintWriter printer = new PrintWriter(codeWriter, true)) {
+      printer.println("def transform(df: DataFrame) : DataFrame = {");
+      printer.println("  val splitted = df.explode(\"body\", \"word\") {");
+      printer.println("    line: String => line.split(\"\\\\s+\")");
+      printer.println("  }");
+      printer.println("  splitted.registerTempTable(\"splitted\")");
+      printer.println("  splitted.sqlContext.sql(\"SELECT word, count(*) as count FROM splitted GROUP BY word\")");
+      printer.println("}");
+    }
+
+    testWordCountCompute(codeWriter.toString());
+  }
+
+  @Test
+  public void testScalaSparkComputeRDD() throws Exception {
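+    // Same word count, but written against the RDD API, using the plugin context to obtain the output schema.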
+    StringWriter codeWriter = new StringWriter();
+    try (PrintWriter printer = new PrintWriter(codeWriter, true)) {
+      printer.println(
+        "def transform(rdd: RDD[StructuredRecord], context:SparkExecutionPluginContext) : RDD[StructuredRecord] = {");
+      printer.println("  val schema = context.getOutputSchema");
+      printer.println("  rdd");
+      printer.println("    .flatMap(_.get[String](\"body\").split(\"\\\\s+\"))");
+      printer.println("    .map(s => (s, 1L))");
+      printer.println("    .reduceByKey(_ + _)");
+      printer.println("    .map(t => StructuredRecord.builder(schema).set(\"word\", t._1).set(\"count\", t._2).build)");
+      printer.println("}");
+    }
+
+    testWordCountCompute(codeWriter.toString());
+  }
+
+  private void testWordCountCompute(String code) throws Exception {
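+    // Shared helper: builds a MockSource -> ScalaSparkCompute -> MockSink pipeline around the given Scala code
+    // and verifies the resulting word counts.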
     Schema inputSchema = Schema.recordOf(
       "input",
       Schema.Field.of("body", Schema.nullableOf(Schema.of(Schema.Type.STRING)))
@@ -281,25 +314,17 @@ public void testScalaSparkCompute() throws Exception {
       Schema.Field.of("count", Schema.nullableOf(Schema.of(Schema.Type.LONG)))
     );
 
-    StringWriter codeWriter = new StringWriter();
-    try (PrintWriter printer = new PrintWriter(codeWriter, true)) {
-      printer.println("def transform(df: DataFrame) : DataFrame = {");
-      printer.println("  val splitted = df.explode(\"body\", \"word\") {");
-      printer.println("    line: String => line.split(\"\\\\s+\")");
-      printer.println("  }");
-      printer.println("  splitted.registerTempTable(\"splitted\")");
-      printer.println("  splitted.sqlContext.sql(\"SELECT word, count(*) as count FROM splitted GROUP BY word\")");
-      printer.println("}");
-    }
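+    // Unique dataset names so each caller of this helper gets its own source and sink datasets.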
+    String inputTable = UUID.randomUUID().toString();
+    String outputTable = UUID.randomUUID().toString();
 
     // Pipeline configuration
     ETLBatchConfig etlConfig = ETLBatchConfig.builder("* * * * *")
-      .addStage(new ETLStage("source", MockSource.getPlugin("singleInput", inputSchema)))
+      .addStage(new ETLStage("source", MockSource.getPlugin(inputTable, inputSchema)))
       .addStage(new ETLStage("compute", new ETLPlugin("ScalaSparkCompute", SparkCompute.PLUGIN_TYPE, ImmutableMap.of(
-        "scalaCode", codeWriter.toString(),
+        "scalaCode", code,
         "schema", computeSchema.toString()
       ))))
-      .addStage(new ETLStage("sink", MockSink.getPlugin("singleOutput")))
+      .addStage(new ETLStage("sink", MockSink.getPlugin(outputTable)))
       .addConnection("source", "compute")
       .addConnection("compute", "sink")
       .build();
@@ -308,11 +333,11 @@ public void testScalaSparkCompute() throws Exception {
     ArtifactSummary artifactSummary = new ArtifactSummary(DATAPIPELINE_ARTIFACT_ID.getArtifact(),
                                                           DATAPIPELINE_ARTIFACT_ID.getVersion());
     AppRequest<ETLBatchConfig> appRequest = new AppRequest<>(artifactSummary, etlConfig);
-    ApplicationId appId = NamespaceId.DEFAULT.app("ScalaSparkComputeApp");
+    ApplicationId appId = NamespaceId.DEFAULT.app(UUID.randomUUID().toString());
     ApplicationManager appManager = deployApplication(appId, appRequest);
 
     // write records to source
-    DataSetManager<Table> inputManager = getDataset(NamespaceId.DEFAULT.dataset("singleInput"));
+    DataSetManager<Table> inputManager = getDataset(NamespaceId.DEFAULT.dataset(inputTable));
     List<StructuredRecord> inputRecords = new ArrayList<>();
     for (int i = 0; i < 10; i++) {
       inputRecords.add(StructuredRecord.builder(inputSchema).set("body", "Line " + i).build());
@@ -326,7 +351,7 @@ public void testScalaSparkCompute() throws Exception {
 
     // Verify result written to sink.
     // It has two fields, word and count.
-    DataSetManager<Table> sinkManager = getDataset("singleOutput");
+    DataSetManager<Table> sinkManager = getDataset(outputTable);
     Map<String, StructuredRecord> wordCounts =
       Maps.uniqueIndex(Sets.newHashSet(MockSink.readOutput(sinkManager)), new Function<StructuredRecord, String>() {
         @Override
@@ -343,13 +368,28 @@ public String apply(StructuredRecord record) {
   }
 
   @Test
-  public void testScalaSparkSink() throws Exception {
-    Schema inputSchema = Schema.recordOf(
-      "input",
-      Schema.Field.of("body", Schema.nullableOf(Schema.of(Schema.Type.STRING)))
-    );
+  public void testScalaSparkSinkRDD() throws Exception {
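+    // RDD-based sink: counts words from "body" and saves "word count" lines as text files under outputFolder.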
+    File testFolder = TEMP_FOLDER.newFolder("scalaSinkRDDOutput");
+    File outputFolder = new File(testFolder, "output");
+    StringWriter codeWriter = new StringWriter();
+    try (PrintWriter printer = new PrintWriter(codeWriter, true)) {
+      printer.println(
+        "def sink(rdd: RDD[StructuredRecord], context:SparkExecutionPluginContext) : Unit = {");
+      printer.println("  val schema = context.getOutputSchema");
+      printer.println("  rdd");
+      printer.println("    .flatMap(_.get[String](\"body\").split(\"\\\\s+\"))");
+      printer.println("    .map(s => (s, 1L))");
+      printer.println("    .reduceByKey(_ + _)");
+      printer.println("    .map(t => t._1 + \" \" + t._2)");
+      printer.println("    .saveAsTextFile(\"" + outputFolder.getAbsolutePath() + "\")");
+      printer.println("}");
+    }
+    testWordCountSink(codeWriter.toString(), outputFolder);
+  }
 
-    File testFolder = TEMP_FOLDER.newFolder("scalaSinkOutput");
+  @Test
+  public void testScalaSparkSinkDataFrame() throws Exception {
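+    // DataFrame-based variant of the sink; the snippet writes the word counts with out.write.format("text").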
+    File testFolder = TEMP_FOLDER.newFolder("scalaSinkDataframeOutput");
     File outputFolder = new File(testFolder, "output");
     StringWriter codeWriter = new StringWriter();
     try (PrintWriter printer = new PrintWriter(codeWriter, true)) {
@@ -363,24 +403,34 @@ public void testScalaSparkSink() throws Exception {
       printer.println("  out.write.format(\"text\").save(\"" + outputFolder.getAbsolutePath() + "\")");
       printer.println("}");
     }
+    testWordCountSink(codeWriter.toString(), outputFolder);
+  }
+
+  private void testWordCountSink(String code, File outputFolder) throws Exception {
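+    // Shared helper: deploys a MockSource -> ScalaSparkSink pipeline around the given Scala code and feeds it input records.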
+    Schema inputSchema = Schema.recordOf(
+      "input",
+      Schema.Field.of("body", Schema.nullableOf(Schema.of(Schema.Type.STRING)))
+    );
+
+    String inputTable = UUID.randomUUID().toString();
 
     // Pipeline configuration
     ETLBatchConfig etlConfig = ETLBatchConfig.builder("* * * * *")
-      .addStage(new ETLStage("source", MockSource.getPlugin("sinkInput", inputSchema)))
+      .addStage(new ETLStage("source", MockSource.getPlugin(inputTable, inputSchema)))
       .addStage(new ETLStage("sink", new ETLPlugin("ScalaSparkSink", SparkSink.PLUGIN_TYPE,
-                                                   ImmutableMap.of("scalaCode", codeWriter.toString()))))
+                                                   ImmutableMap.of("scalaCode", code))))
       .addConnection("source", "sink")
       .build();
 
     // Deploy the pipeline
     ArtifactSummary artifactSummary = new ArtifactSummary(DATAPIPELINE_ARTIFACT_ID.getArtifact(),
                                                           DATAPIPELINE_ARTIFACT_ID.getVersion());
     AppRequest<ETLBatchConfig> appRequest = new AppRequest<>(artifactSummary, etlConfig);
-    ApplicationId appId = NamespaceId.DEFAULT.app("ScalaSparkSinkApp");
+    ApplicationId appId = NamespaceId.DEFAULT.app(UUID.randomUUID().toString());
     ApplicationManager appManager = deployApplication(appId, appRequest);
 
     // write records to source
-    DataSetManager<Table> inputManager = getDataset(NamespaceId.DEFAULT.dataset("sinkInput"));
+    DataSetManager<Table> inputManager = getDataset(NamespaceId.DEFAULT.dataset(inputTable));
     List<StructuredRecord> inputRecords = new ArrayList<>();
     for (int i = 0; i < 10; i++) {
       inputRecords.add(StructuredRecord.builder(inputSchema).set("body", "Line " + i).build());