diff --git a/R/analyze.R b/R/analyze.R index 33a84f5..2fb865a 100644 --- a/R/analyze.R +++ b/R/analyze.R @@ -40,8 +40,11 @@ analyze <- function(preset, progress = NULL) { "correlation_positions" = function(...) { correlation(..., use_positions = TRUE) }, - "proximity" = proximity, - "neural" = neural + "neural" = neural, + "neural_positions" = function(...) { + neural(..., use_positions = TRUE) + }, + "proximity" = proximity ) results <- cached("analysis", preset, { diff --git a/R/neural.R b/R/neural.R index 730300c..2987ed1 100644 --- a/R/neural.R +++ b/R/neural.R @@ -1,14 +1,17 @@ # Find genes by training a neural network on reference position data. # # @param seed A seed to get reproducible results. -neural <- function(preset, progress = NULL, seed = 448077) { +neural <- function(preset, + use_positions = FALSE, + progress = NULL, + seed = 448077) { species_ids <- preset$species_ids gene_ids <- preset$gene_ids reference_gene_ids <- preset$reference_gene_ids cached( "neural", - c(species_ids, gene_ids, reference_gene_ids), + c(species_ids, gene_ids, reference_gene_ids, use_positions), { # nolint set.seed(seed) gene_count <- length(gene_ids) @@ -28,10 +31,17 @@ neural <- function(preset, progress = NULL, seed = 448077) { # Make a column containing positions for each species. for (species_id in species_ids) { - species_data <- distances[ - species == species_id, - .(gene, position) - ] + species_data <- if (use_positions) { + setnames(distances[ + species == species_id, + .(gene, position) + ], "position", "distance") + } else { + distances[ + species == species_id, + .(gene, distance) + ] + } # Only include species with at least 25% known values. @@ -48,11 +58,11 @@ neural <- function(preset, progress = NULL, seed = 448077) { # However, this will of course lessen the significance of # the results. - mean_position <- round(species_data[, mean(position)]) - data[is.na(position), position := mean_position] + mean_distance <- round(species_data[, mean(distance)]) + data[is.na(distance), distance := mean_distance] # Name the new column after the species. - setnames(data, "position", species_id) + setnames(data, "distance", species_id) } }